diff --git a/.editorconfig b/.editorconfig index 3d99a200a3..e3a605df50 100644 --- a/.editorconfig +++ b/.editorconfig @@ -7,3 +7,16 @@ insert_final_newline = true max_line_length = 125 trim_trailing_whitespace = true +ktlint_experimental = disabled +ktlint_standard_argument-list-wrapping = disabled +ktlint_standard_block-wrapping = disabled +ktlint_standard_chain-wrapping = disabled +ktlint_standard_function-expression-body = disabled +ktlint_standard_function-signature = disabled +ktlint_standard_import-ordering = disabled +ktlint_standard_multiline-expression-wrapping = disabled +ktlint_standard_parameter-list-wrapping = disabled +ktlint_standard_property-wrapping = disabled +ktlint_standard_trailing-comma-on-call-site = disabled +ktlint_standard_trailing-comma-on-declaration-site = disabled +ktlint_standard_wrapping = disabled diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index fd36855aa8..a06546b27a 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -171,7 +171,7 @@ jobs: run: ./gradlew licensee --no-configuration-cache - name: 'Analyse' - run: ./gradlew detekt ktlintCheck lint -x lintRelease ${{ env.SCAN }} + run: ./gradlew detekt lint -x lintRelease ${{ env.SCAN }} - name: 'Archive analysis reports' uses: actions/upload-artifact@v6 @@ -182,7 +182,7 @@ jobs: **/build/reports/detekt - name: 'Unit tests' - run: ./gradlew :selekt-android:testDebugUnitTest :selekt-java:test :selekt-sqlite3-classes:testJava17 :selekt-sqlite3-classes:testJava25 :selekt-common:test :koverHtmlReport -x integrationTest ${{ env.SCAN }} + run: ./gradlew :selekt-android:testDebugUnitTest :selekt-java:test :selekt-jdbc:test :selekt-sqlite3-classes:testJava17 :selekt-sqlite3-classes:testJava25 :selekt-common:test :koverHtmlReport -x integrationTest ${{ env.SCAN }} - name: 'Archive test reports' uses: actions/upload-artifact@v6 @@ -202,7 +202,7 @@ jobs: - name: 'Build Selekt JVM' run: | - ./gradlew :selekt-jvm:jar ${{ env.SCAN }} + ./gradlew :selekt-jdbc:jar :selekt-jvm:jar ${{ env.SCAN }} - name: 'Verify coverage' run: diff --git a/.github/workflows/publication.yml b/.github/workflows/publication.yml index d23beb13b0..46c0023e76 100644 --- a/.github/workflows/publication.yml +++ b/.github/workflows/publication.yml @@ -210,7 +210,7 @@ jobs: echo "org.gradle.java.home=${JAVA_HOME}" >> gradle.properties - name: 'Unit tests' - run: ./gradlew :selekt-android:testDebugUnitTest :selekt-java:test :selekt-sqlite3-classes:testJava17 :selekt-sqlite3-classes:testJava25 :selekt-common:test -x integrationTest ${{ env.SCAN }} + run: ./gradlew :selekt-android:testDebugUnitTest :selekt-java:test :selekt-jdbc:test :selekt-sqlite3-classes:testJava17 :selekt-sqlite3-classes:testJava25 :selekt-common:test -x integrationTest ${{ env.SCAN }} - name: 'Build Selekt Android' run: | @@ -224,7 +224,7 @@ jobs: - name: 'Build Selekt JVM' run: | - ./gradlew :selekt-jvm:jar ${{ env.SCAN }} + ./gradlew :selekt-jdbc:jar :selekt-jvm:jar ${{ env.SCAN }} - name: 'Publish release to OSSRH' if: github.event_name == 'release' && github.event.action == 'published' diff --git a/AndroidCLI/build.gradle.kts b/AndroidCLI/build.gradle.kts index 4cff7abae8..8003eef40d 100644 --- a/AndroidCLI/build.gradle.kts +++ b/AndroidCLI/build.gradle.kts @@ -18,7 +18,6 @@ plugins { id("com.android.application") id("kotlin-android") alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { @@ -31,7 +30,7 @@ android { namespace = "com.bloomberg.selekt.cli" defaultConfig { applicationId = "com.bloomberg.selekt.cli" - minSdk = 21 + minSdk = 24 targetSdk = 34 versionCode = 1 versionName = "0.1" diff --git a/AndroidLibBenchmark/build.gradle.kts b/AndroidLibBenchmark/build.gradle.kts index d71c60b62e..a9410a9822 100644 --- a/AndroidLibBenchmark/build.gradle.kts +++ b/AndroidLibBenchmark/build.gradle.kts @@ -19,7 +19,6 @@ plugins { id("kotlin-android") alias(libs.plugins.androidx.benchmark) alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { @@ -31,7 +30,7 @@ android { compileSdkVersion(Versions.ANDROID_SDK.version.toInt()) namespace = "com.bloomberg.selekt.android.benchmark" defaultConfig { - minSdkVersion(21) + minSdkVersion(24) targetSdkVersion(34) testInstrumentationRunner = "androidx.benchmark.junit4.AndroidBenchmarkRunner" testInstrumentationRunnerArguments.putAll(arrayOf( diff --git a/SQLite3/sqlite3_jni.cpp b/SQLite3/sqlite3_jni.cpp index a20a27db7c..976b450918 100644 --- a/SQLite3/sqlite3_jni.cpp +++ b/SQLite3/sqlite3_jni.cpp @@ -151,6 +151,20 @@ Java_com_bloomberg_selekt_ExternalSQLite_bindParameterCount( return sqlite3_bind_parameter_count(statement); } +extern "C" JNIEXPORT jint JNICALL +Java_com_bloomberg_selekt_ExternalSQLite_bindParameterIndex( + JNIEnv* env, + jobject obj, + jlong jstatement, + jstring jname +) { + auto statement = reinterpret_cast(jstatement); + auto name = env->GetStringUTFChars(jname, nullptr); + auto result = sqlite3_bind_parameter_index(statement, name); + env->ReleaseStringUTFChars(jname, name); + return result; +} + extern "C" JNIEXPORT jint JNICALL Java_com_bloomberg_selekt_ExternalSQLite_bindText( JNIEnv* env, diff --git a/build.gradle.kts b/build.gradle.kts index ede205f108..d664cbb961 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -28,9 +28,6 @@ import org.jetbrains.gradle.ext.copyright import org.jetbrains.gradle.ext.settings import org.jetbrains.kotlin.gradle.dsl.JvmTarget import org.jetbrains.kotlin.gradle.tasks.KotlinCompile -import org.jlleitschuh.gradle.ktlint.KtlintExtension -import org.jlleitschuh.gradle.ktlint.reporter.ReporterType -import org.jlleitschuh.gradle.ktlint.tasks.GenerateReportsTask plugins { base @@ -38,7 +35,6 @@ plugins { alias(libs.plugins.kover) alias(libs.plugins.nexus) alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) alias(libs.plugins.ideaExt) alias(libs.plugins.qodana) alias(libs.plugins.ksp) apply false @@ -71,6 +67,7 @@ dependencies { kover(projects.selektApi) kover(projects.selektCommons) kover(projects.selektJava) + kover(projects.selektJdbc) kover(projects.selektJvm) kover(projects.selektSqlite3Classes) } @@ -212,20 +209,6 @@ subprojects { } } -allprojects { - plugins.withId("org.jlleitschuh.gradle.ktlint") { - configure { - disabledRules.set(setOf("import-ordering", "indent", "wrapping")) - reporters { - reporter(ReporterType.HTML) - } - } - } - tasks.withType().configureEach { - reportsOutputDirectory.set(rootProject.layout.buildDirectory.dir("reports/ktlint/${project.name}/$name")) - } -} - koverReport { defaults { filters { @@ -240,7 +223,7 @@ koverReport { verify { rule("Minimal coverage") { bound { - minValue = 96 + minValue = 92 aggregation = AggregationType.COVERED_PERCENTAGE } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 2268bdfd4c..fdd6a23b44 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -36,6 +36,7 @@ kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-c mockito-core = { module = "org.mockito:mockito-core", version = "5.21.0" } mockito-kotlin = { module = "org.mockito.kotlin:mockito-kotlin", version = "6.2.1" } robolectric-android-all = { module = "org.robolectric:android-all", version = "12.1-robolectric-8229987" } +slf4j-api = { module = "org.slf4j:slf4j-api", version = "2.0.17" } xerial-sqlite-jdbc = { module = "org.xerial:sqlite-jdbc", version = "3.51.1.0" } [bundles] @@ -49,7 +50,6 @@ ideaExt = { id = "org.jetbrains.gradle.plugin.idea-ext", version = "1.1.7" } jmh = { id = "me.champeau.jmh", version = "0.7.3" } kover = { id = "org.jetbrains.kotlinx.kover", version = "0.7.6" } ksp = { id = "com.google.devtools.ksp", version = "2.3.4" } -ktlint = { id = "org.jlleitschuh.gradle.ktlint", version = "11.5.0" } nexus = { id = "io.github.gradle-nexus.publish-plugin", version = "2.0.0" } qodana = { id = "org.jetbrains.qodana", version = "0.1.12" } undercouch-download = { id = "de.undercouch.download", version = "5.4.0" } diff --git a/selekt-android-lint/build.gradle.kts b/selekt-android-lint/build.gradle.kts index 3eac1f4684..a21f261d28 100644 --- a/selekt-android-lint/build.gradle.kts +++ b/selekt-android-lint/build.gradle.kts @@ -21,7 +21,6 @@ plugins { `maven-publish` signing alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { diff --git a/selekt-android-sqlcipher/build.gradle.kts b/selekt-android-sqlcipher/build.gradle.kts index 5715f8c0dd..a7ba79fc7b 100644 --- a/selekt-android-sqlcipher/build.gradle.kts +++ b/selekt-android-sqlcipher/build.gradle.kts @@ -27,7 +27,6 @@ plugins { `maven-publish` signing alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { @@ -43,7 +42,7 @@ android { namespace = "com.bloomberg.selekt.android.sqlcipher" ndkVersion = "27.3.13750724" defaultConfig { - minSdk = 21 + minSdk = 24 @Suppress("UnstableApiUsage") externalNativeBuild { cmake { diff --git a/selekt-android/build.gradle.kts b/selekt-android/build.gradle.kts index 98a010a1e6..62678121da 100644 --- a/selekt-android/build.gradle.kts +++ b/selekt-android/build.gradle.kts @@ -29,7 +29,6 @@ plugins { signing alias(libs.plugins.kover) alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { @@ -41,7 +40,7 @@ android { compileSdk = Versions.ANDROID_SDK.version.toInt() namespace = "com.bloomberg.selekt.android" defaultConfig { - minSdk = 21 + minSdk = 24 testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" } buildTypes { diff --git a/selekt-api/build.gradle.kts b/selekt-api/build.gradle.kts index c05e453706..0be35f4002 100644 --- a/selekt-api/build.gradle.kts +++ b/selekt-api/build.gradle.kts @@ -24,7 +24,6 @@ plugins { signing alias(libs.plugins.detekt) alias(libs.plugins.kover) - alias(libs.plugins.ktlint) } repositories { diff --git a/selekt-api/src/main/kotlin/com/bloomberg/selekt/ISQLProgram.kt b/selekt-api/src/main/kotlin/com/bloomberg/selekt/ISQLProgram.kt index 48d82bcce1..a0c1eaa840 100644 --- a/selekt-api/src/main/kotlin/com/bloomberg/selekt/ISQLProgram.kt +++ b/selekt-api/src/main/kotlin/com/bloomberg/selekt/ISQLProgram.kt @@ -18,6 +18,7 @@ package com.bloomberg.selekt import java.io.Closeable +@Suppress("Detekt.ComplexInterface", "Detekt.TooManyFunctions") interface ISQLProgram : Closeable { /** * Bind a byte array value to this statement. The value remains bound until [clearBindings] is called. @@ -27,6 +28,17 @@ interface ISQLProgram : Closeable { */ fun bindBlob(index: Int, value: ByteArray) + /** + * Bind a byte array value to this statement by parameter name. The value remains bound until [clearBindings] is called. + * + * Named parameters in SQLite can use the following syntax: :name, @name, or $name. + * + * @param name The name of the parameter to bind (including the prefix character like :, @, or $). + * @param value The value to bind. + * @throws IllegalArgumentException if the parameter name is not found. + */ + fun bindBlob(name: String, value: ByteArray) + /** * Bind a double value to this statement. The value remains bound until [clearBindings] is called. * @@ -35,6 +47,17 @@ interface ISQLProgram : Closeable { */ fun bindDouble(index: Int, value: Double) + /** + * Bind a double value to this statement by parameter name. The value remains bound until [clearBindings] is called. + * + * Named parameters in SQLite can use the following syntax: :name, @name, or $name. + * + * @param name The name of the parameter to bind (including the prefix character like :, @, or $). + * @param value The value to bind. + * @throws IllegalArgumentException if the parameter name is not found. + */ + fun bindDouble(name: String, value: Double) + /** * Bind an integer value to this statement. The value remains bound until [clearBindings] is called. * @@ -43,6 +66,17 @@ interface ISQLProgram : Closeable { */ fun bindInt(index: Int, value: Int) + /** + * Bind an integer value to this statement by parameter name. The value remains bound until [clearBindings] is called. + * + * Named parameters in SQLite can use the following syntax: :name, @name, or $name. + * + * @param name The name of the parameter to bind (including the prefix character like :, @, or $). + * @param value The value to bind. + * @throws IllegalArgumentException if the parameter name is not found. + */ + fun bindInt(name: String, value: Int) + /** * Bind a long value to this statement. The value remains bound until [clearBindings] is called. * @@ -51,6 +85,17 @@ interface ISQLProgram : Closeable { */ fun bindLong(index: Int, value: Long) + /** + * Bind a long value to this statement by parameter name. The value remains bound until [clearBindings] is called. + * + * Named parameters in SQLite can use the following syntax: :name, @name, or $name. + * + * @param name The name of the parameter to bind (including the prefix character like :, @, or $). + * @param value The value to bind. + * @throws IllegalArgumentException if the parameter name is not found. + */ + fun bindLong(name: String, value: Long) + /** * Bind a null value to this statement. The value remains bound until [clearBindings] is called. * @@ -58,6 +103,16 @@ interface ISQLProgram : Closeable { */ fun bindNull(index: Int) + /** + * Bind a null value to this statement by parameter name. The value remains bound until [clearBindings] is called. + * + * Named parameters in SQLite can use the following syntax: :name, @name, or $name. + * + * @param name The name of the parameter to bind (including the prefix character like :, @, or $). + * @throws IllegalArgumentException if the parameter name is not found. + */ + fun bindNull(name: String) + /** * Bind a String value to this statement. The value remains bound until [clearBindings] is called. * @@ -66,6 +121,17 @@ interface ISQLProgram : Closeable { */ fun bindString(index: Int, value: String) + /** + * Bind a String value to this statement by parameter name. The value remains bound until [clearBindings] is called. + * + * Named parameters in SQLite can use the following syntax: :name, @name, or $name. + * + * @param name The name of the parameter to bind (including the prefix character like :, @, or $). + * @param value The value to bind. + * @throws IllegalArgumentException if the parameter name is not found. + */ + fun bindString(name: String, value: String) + /** * Clears all existing bindings. Unset bindings are treated as null. */ diff --git a/selekt-bom/build.gradle.kts b/selekt-bom/build.gradle.kts index 1778b2d73a..517b270245 100644 --- a/selekt-bom/build.gradle.kts +++ b/selekt-bom/build.gradle.kts @@ -31,6 +31,7 @@ dependencies { selektAndroidSqlcipher, selektApi, selektJava, + selektJdbc, selektSqlite3Classes, selektSqlite3Sqlcipher ) diff --git a/selekt-commons/build.gradle.kts b/selekt-commons/build.gradle.kts index e9267cecee..118ea5a7a9 100644 --- a/selekt-commons/build.gradle.kts +++ b/selekt-commons/build.gradle.kts @@ -24,7 +24,6 @@ plugins { `maven-publish` signing alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { diff --git a/selekt-java/build.gradle.kts b/selekt-java/build.gradle.kts index bf455a7cf1..d5ab8b20ab 100644 --- a/selekt-java/build.gradle.kts +++ b/selekt-java/build.gradle.kts @@ -29,7 +29,6 @@ plugins { signing alias(libs.plugins.jmh) alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { @@ -76,6 +75,8 @@ dependencies { requireCapability("com.bloomberg.selekt:selekt-sqlite3-classes-java17") } } + jmhImplementation(projects.selektJdbc) + jmhImplementation(projects.selektSqlite3Sqlcipher) jmhImplementation(libs.kotlinx.coroutines.core) jmhImplementation(libs.xerial.sqlite.jdbc) } diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/BatchSQLExecutor.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/BatchSQLExecutor.kt index 14a5bcb318..37f7fecbc7 100644 --- a/selekt-java/src/main/kotlin/com/bloomberg/selekt/BatchSQLExecutor.kt +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/BatchSQLExecutor.kt @@ -20,11 +20,20 @@ import java.util.stream.Stream import kotlin.streams.asSequence internal interface BatchSQLExecutor { - fun executeBatchForChangedRowCount(sql: String, bindArgs: Sequence>): Int - fun executeBatchForChangedRowCount(sql: String, bindArgs: Iterable>): Int = executeBatchForChangedRowCount(sql, bindArgs.asSequence()) + fun executeBatchForChangedRowCount(sql: String, bindArgs: List>): Int + + fun executeBatchForChangedRowCount(sql: String, bindArgs: Sequence>): Int + + fun executeBatchForChangedRowCount( + sql: String, + bindArgs: Array>, + fromIndex: Int = 0, + toIndex: Int = bindArgs.size + ): Int + fun executeBatchForChangedRowCount(sql: String, bindArgs: Stream>): Int = executeBatchForChangedRowCount(sql, bindArgs.asSequence()) } diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/Databases.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/Databases.kt index 1948f9b91d..366da638f4 100644 --- a/selekt-java/src/main/kotlin/com/bloomberg/selekt/Databases.kt +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/Databases.kt @@ -68,6 +68,19 @@ class SQLDatabase( SQLStatement.execute(session, sql, bindArgs) } + fun batch(sql: String, bindArgs: List>): Int = transact { + SQLStatement.execute(session, sql, bindArgs) + } + + fun batch( + sql: String, + bindArgs: Array>, + fromIndex: Int = 0, + toIndex: Int = bindArgs.size + ): Int = transact { + SQLStatement.execute(session, sql, bindArgs, fromIndex, toIndex) + } + fun batch(sql: String, bindArgs: Iterable>): Int = transact { SQLStatement.execute(session, sql, bindArgs) } diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/Queries.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/Queries.kt index 3b09c183f0..adb2a59993 100644 --- a/selekt-java/src/main/kotlin/com/bloomberg/selekt/Queries.kt +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/Queries.kt @@ -32,6 +32,8 @@ internal class SQLQuery internal constructor( private val statementType: SQLStatementType, private val bindArgs: Array ) : IQuery { + private val namedParameters: Map by lazy { parseNamedParameters(sql) } + companion object { fun create( session: ThreadLocalSession, @@ -51,16 +53,30 @@ internal class SQLQuery internal constructor( override fun bindBlob(index: Int, value: ByteArray) = bind(index, value) + override fun bindBlob(name: String, value: ByteArray) { + bind(resolveParameterIndex(name), value) + } + override fun bindDouble(index: Int, value: Double) = bind(index, value) + override fun bindDouble(name: String, value: Double) = bind(resolveParameterIndex(name), value) + override fun bindInt(index: Int, value: Int) = bind(index, value) + override fun bindInt(name: String, value: Int) = bind(resolveParameterIndex(name), value) + override fun bindLong(index: Int, value: Long) = bind(index, value) + override fun bindLong(name: String, value: Long) = bind(resolveParameterIndex(name), value) + override fun bindNull(index: Int) = bind(index, null) + override fun bindNull(name: String) = bind(resolveParameterIndex(name), null) + override fun bindString(index: Int, value: String) = bind(index, value) + override fun bindString(name: String, value: String) = bind(resolveParameterIndex(name), value) + override fun clearBindings() { bindArgs.fill(null) } @@ -101,6 +117,10 @@ internal class SQLQuery internal constructor( private fun bind(index: Int, arg: Any?) { bindArgs[index - 1] = arg } + + private fun resolveParameterIndex(name: String): Int = namedParameters[name] ?: throw IllegalArgumentException( + "Named parameter '$name' not found in SQL. Available parameters: ${namedParameters.keys}" + ) } class SimpleSQLQuery( diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLConnection.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLConnection.kt index f406677a20..190b1e5951 100644 --- a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLConnection.kt +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLConnection.kt @@ -111,6 +111,38 @@ internal class SQLConnection( sqlite.totalChanges(pointer) - changes } + override fun executeBatchForChangedRowCount( + sql: String, + bindArgs: List> + ) = withPreparedStatement(sql) { + val changes = sqlite.totalChanges(pointer) + for (i in bindArgs.indices) { + reset() + bindArguments(bindArgs[i]) + if (SQL_DONE != step()) { + return@withPreparedStatement -1 + } + } + sqlite.totalChanges(pointer) - changes + } + + override fun executeBatchForChangedRowCount( + sql: String, + bindArgs: Array>, + fromIndex: Int, + toIndex: Int + ) = withPreparedStatement(sql) { + val changes = sqlite.totalChanges(pointer) + for (i in fromIndex until toIndex) { + reset() + bindArguments(bindArgs[i]) + if (SQL_DONE != step()) { + return@withPreparedStatement -1 + } + } + sqlite.totalChanges(pointer) - changes + } + override fun executeForCursorWindow( sql: String, bindArgs: Array, diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLParameterParser.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLParameterParser.kt new file mode 100644 index 0000000000..b9a93306f9 --- /dev/null +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLParameterParser.kt @@ -0,0 +1,123 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt + +/** + * Parses SQL to extract named parameter positions. + * + * SQLite supports these named parameter forms: + * - :name (colon prefix) + * - @name (at sign prefix) + * - $name (dollar sign prefix) + * + * Anonymous parameters (?) are assigned positions in order of appearance. Named parameters are also assigned + * positions in order of appearance. + * + * This parser handles: + * - String literals (single quotes), parameters inside are ignored + * - Identifiers (double quotes or backticks), parameters inside are ignored + * - Comments (-- line comments and block comments), parameters inside are ignored + */ +@Suppress("Detekt.CognitiveComplexMethod") +internal fun parseNamedParameters(sql: String): Map { + val result = mutableMapOf() + var parameterIndex = 0 + var i = 0 + while (i < sql.length) { + when (sql[i]) { + '\'' -> i = skipStringLiteral(sql, i, '\'') + '"' -> i = skipStringLiteral(sql, i, '"') + '`' -> i = skipStringLiteral(sql, i, '`') + '[' -> i = skipBracketIdentifier(sql, i) + '-' if i + 1 < sql.length && sql[i + 1] == '-' -> i = skipLineComment(sql, i) + '/' if i + 1 < sql.length && sql[i + 1] == '*' -> i = skipBlockComment(sql, i) + '?' -> { + ++parameterIndex + if (i + 1 < sql.length && sql[i + 1].isDigit()) { + i = skipDigits(sql, i + 1) + } else { + i++ + } + } + ':', '@', '$' -> { + ++parameterIndex + val startIndex = i++ + while (i < sql.length && sql[i].isParameterNameChar()) { + i++ + } + if (i - startIndex > 1) { + result.putIfAbsent(sql.substring(startIndex, i), parameterIndex) + } + } + else -> i++ + } + } + return result +} + +private fun Char.isParameterNameChar(): Boolean = isLetterOrDigit() || this == '_' + +private fun skipStringLiteral(sql: String, start: Int, quote: Char): Int { + var i = start + 1 + while (i < sql.length) { + if (sql[i] == quote) { + if (i + 1 < sql.length && sql[i + 1] == quote) { + i += 2 + } else { + return i + 1 + } + } else { + i++ + } + } + return i +} + +private fun skipBracketIdentifier(sql: String, start: Int): Int { + var i = start + 1 + while (i < sql.length && sql[i] != ']') { + i++ + } + return if (i < sql.length) i + 1 else i +} + +private fun skipLineComment(sql: String, start: Int): Int { + var i = start + 2 + while (i < sql.length && sql[i] != '\n') { + i++ + } + return if (i < sql.length) i + 1 else i +} + +private fun skipBlockComment(sql: String, start: Int): Int { + var i = start + 2 + while (i + 1 < sql.length) { + if (sql[i] == '*' && sql[i + 1] == '/') { + return i + 2 + } + i++ + } + return sql.length +} + +private fun skipDigits(sql: String, start: Int): Int { + var i = start + while (i < sql.length && sql[i].isDigit()) { + i++ + } + return i +} diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLPreparedStatement.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLPreparedStatement.kt index 8c9dd83f1f..e26213d67a 100644 --- a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLPreparedStatement.kt +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLPreparedStatement.kt @@ -24,7 +24,7 @@ private const val NANOS_PER_MILLI = 1_000_000L private const val MAX_PAUSE_MILLIS = 100L @NotThreadSafe -@Suppress("Detekt.TooManyFunctions") +@Suppress("Detekt.MethodOverloading", "Detekt.TooManyFunctions") internal class SQLPreparedStatement( private var pointer: Pointer, private var rawSql: String, @@ -73,26 +73,50 @@ internal class SQLPreparedStatement( sqlite.bindBlob(pointer, index, value) } + fun bind(name: String, value: ByteArray) { + sqlite.bindBlob(pointer, resolveParameterIndex(name), value) + } + fun bind(index: Int, value: Double) { sqlite.bindDouble(pointer, index, value) } + fun bind(name: String, value: Double) { + sqlite.bindDouble(pointer, resolveParameterIndex(name), value) + } + fun bind(index: Int, value: Int) { sqlite.bindInt(pointer, index, value) } + fun bind(name: String, value: Int) { + sqlite.bindInt(pointer, resolveParameterIndex(name), value) + } + fun bind(index: Int, value: Long) { sqlite.bindInt64(pointer, index, value) } + fun bind(name: String, value: Long) { + sqlite.bindInt64(pointer, resolveParameterIndex(name), value) + } + fun bind(index: Int, value: String) { sqlite.bindText(pointer, index, value) } + fun bind(name: String, value: String) { + sqlite.bindText(pointer, resolveParameterIndex(name), value) + } + fun bindNull(index: Int) { sqlite.bindNull(pointer, index) } + fun bindNull(name: String) { + sqlite.bindNull(pointer, resolveParameterIndex(name)) + } + fun bindZeroBlob(index: Int, length: Int) { sqlite.bindZeroBlob(pointer, index, length) } @@ -159,4 +183,8 @@ internal class SQLPreparedStatement( } private fun Long.nextRandom() = random.nextLong(this) + + private fun resolveParameterIndex(name: String): Int = sqlite.bindParameterIndex(pointer, name).also { + require(it > 0) { "Named parameter '$name' not found in SQL statement." } + } } diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLStatement.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLStatement.kt index 3102e8f945..82ac7ef63b 100644 --- a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLStatement.kt +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLStatement.kt @@ -47,6 +47,7 @@ internal enum class SQLStatementType( } private const val SUFFICIENT_SQL_PREFIX_LENGTH = 3 +private const val ONLY_BATCH_UPDATES = "Only batched updates are permitted." @JvmSynthetic @Suppress("Detekt.CognitiveComplexMethod", "Detekt.ComplexCondition", "Detekt.MagicNumber", "Detekt.NestedBlockDepth") @@ -97,6 +98,7 @@ internal fun String.resolvedSqlStatementType(): SQLStatementType = trimStartByIn } } +@Suppress("Detekt.MethodOverloading") @ThreadSafe internal class SQLStatement private constructor( private val session: ThreadLocalSession, @@ -105,6 +107,9 @@ internal class SQLStatement private constructor( private val args: Array, private val asWrite: Boolean ) : ISQLStatement { + private val namedParameters: Map by lazy { parseNamedParameters(sql) } + + @Suppress("Detekt.TooManyFunctions") companion object { fun execute( session: ThreadLocalSession, @@ -120,14 +125,36 @@ internal class SQLStatement private constructor( sql: String, bindArgs: Sequence> ): Int { - require(SQLStatementType.UPDATE === sql.resolvedSqlStatementType()) { - "Only batched updates are permitted." + require(SQLStatementType.UPDATE === sql.resolvedSqlStatementType()) { ONLY_BATCH_UPDATES } + return session.get().execute(true, sql) { + it.executeBatchForChangedRowCount(sql, bindArgs) } + } + + fun execute( + session: ThreadLocalSession, + sql: String, + bindArgs: List> + ): Int { + require(SQLStatementType.UPDATE === sql.resolvedSqlStatementType()) { ONLY_BATCH_UPDATES } return session.get().execute(true, sql) { it.executeBatchForChangedRowCount(sql, bindArgs) } } + fun execute( + session: ThreadLocalSession, + sql: String, + bindArgs: Array>, + fromIndex: Int = 0, + toIndex: Int = bindArgs.size + ): Int { + require(SQLStatementType.UPDATE === sql.resolvedSqlStatementType()) { ONLY_BATCH_UPDATES } + return session.get().execute(true, sql) { + it.executeBatchForChangedRowCount(sql, bindArgs, fromIndex, toIndex) + } + } + fun execute( session: ThreadLocalSession, sql: String, @@ -199,26 +226,50 @@ internal class SQLStatement private constructor( bind(index, value) } + override fun bindBlob(name: String, value: ByteArray) { + bind(resolveParameterIndex(name), value) + } + override fun bindDouble(index: Int, value: Double) { bind(index, value) } + override fun bindDouble(name: String, value: Double) { + bind(resolveParameterIndex(name), value) + } + override fun bindInt(index: Int, value: Int) { bind(index, value) } + override fun bindInt(name: String, value: Int) { + bind(resolveParameterIndex(name), value) + } + override fun bindLong(index: Int, value: Long) { bind(index, value) } + override fun bindLong(name: String, value: Long) { + bind(resolveParameterIndex(name), value) + } + override fun bindNull(index: Int) { bind(index, null) } + override fun bindNull(name: String) { + bind(resolveParameterIndex(name), null) + } + override fun bindString(index: Int, value: String) { bind(index, value) } + override fun bindString(name: String, value: String) { + bind(resolveParameterIndex(name), value) + } + override fun clearBindings() { args.fill(null) } @@ -250,6 +301,11 @@ internal class SQLStatement private constructor( private fun bind(index: Int, value: Any?) { args[index - 1] = value } + + private fun resolveParameterIndex(name: String): Int = + namedParameters[name] ?: throw IllegalArgumentException( + "Named parameter '$name' not found in SQL. Available parameters: ${namedParameters.keys}" + ) } @Suppress("UseDataClass") diff --git a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLite.kt b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLite.kt index c92a53c56c..1fcf4a02ad 100644 --- a/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLite.kt +++ b/selekt-java/src/main/kotlin/com/bloomberg/selekt/SQLite.kt @@ -49,6 +49,8 @@ open class SQLite( fun bindParameterCount(statement: Long) = sqlite.bindParameterCount(statement) + fun bindParameterIndex(statement: Long, name: String) = sqlite.bindParameterIndex(statement, name) + fun bindText(statement: Long, index: Int, value: String) = checkBindSQLCode( statement, sqlite.bindText(statement, index, value) diff --git a/selekt-java/src/test/kotlin/com/bloomberg/selekt/SQLParameterParserTest.kt b/selekt-java/src/test/kotlin/com/bloomberg/selekt/SQLParameterParserTest.kt new file mode 100644 index 0000000000..7f7bcc96c8 --- /dev/null +++ b/selekt-java/src/test/kotlin/com/bloomberg/selekt/SQLParameterParserTest.kt @@ -0,0 +1,158 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt + +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +internal class SQLParameterParserTest { + @Test + fun parseSimpleColonParameter() { + assertEquals( + mapOf(":name" to 1), + parseNamedParameters("SELECT * FROM users WHERE name = :name") + ) + } + + @Test + fun parseSimpleAtParameter() { + assertEquals( + mapOf("@name" to 1), + parseNamedParameters("SELECT * FROM users WHERE name = @name") + ) + } + + @Test + fun parseSimpleDollarParameter() { + assertEquals(mapOf("\$name" to 1), parseNamedParameters( + "SELECT * FROM users WHERE name = \$name" + )) + } + + @Test + fun parseMultipleParameters() { + assertEquals( + mapOf(":name" to 1, ":minAge" to 2, "@city" to 3), + parseNamedParameters("SELECT * FROM users WHERE name = :name AND age > :minAge AND city = @city") + ) + } + + @Test + fun parseMixedPositionalAndNamedParameters() { + assertEquals(mapOf(":name" to 2), parseNamedParameters( + "SELECT * FROM users WHERE id = ? AND name = :name AND age > ?" + )) + } + + @Test + fun parseParameterInsideStringLiteralIgnored() { + assertEquals( + mapOf(":age" to 1), + parseNamedParameters("SELECT * FROM users WHERE name = ':notAParam' AND age = :age") + ) + } + + @Test + fun parseParameterInsideDoubleQuotedIdentifierIgnored() { + assertEquals( + mapOf(":age" to 1), + parseNamedParameters("SELECT * FROM \":notATable\" WHERE age = :age") + ) + } + + @Test + fun parseParameterInsideBacktickIdentifierIgnored() { + assertEquals( + mapOf(":age" to 1), + parseNamedParameters("SELECT * FROM `:notATable` WHERE age = :age") + ) + } + + @Test + fun parseParameterInsideBracketIdentifierIgnored() { + assertEquals( + mapOf(":age" to 1), + parseNamedParameters("SELECT * FROM [:notATable] WHERE age = :age") + ) + } + + @Test + fun parseParameterInsideLineCommentIgnored() { + assertEquals( + mapOf(":age" to 1), + parseNamedParameters("SELECT * FROM users -- WHERE name = :notAParam\nWHERE age = :age") + ) + } + + @Test + fun parseParameterInsideBlockCommentIgnored() { + assertEquals( + mapOf(":age" to 1), + parseNamedParameters("SELECT * FROM users /* WHERE name = :notAParam */ WHERE age = :age") + ) + } + + @Test + fun parseDuplicateParameterReturnsFirstIndex() { + assertEquals( + mapOf(":name" to 1), + parseNamedParameters("SELECT * FROM users WHERE name = :name OR alias = :name") + ) + } + + @Test + fun parseParameterWithUnderscores() { + assertEquals( + mapOf(":first_name" to 1), + parseNamedParameters("SELECT * FROM users WHERE first_name = :first_name") + ) + } + + @Test + fun parseParameterWithNumbers() { + assertEquals( + mapOf(":id1" to 1, ":code2" to 2), + parseNamedParameters("SELECT * FROM users WHERE id = :id1 AND code = :code2") + ) + } + + @Test + fun parseNumberedQuestionMark() { + assertEquals( + mapOf(":name" to 2), + parseNamedParameters("SELECT * FROM users WHERE id = ?1 AND name = :name AND code = ?2") + ) + } + + @Test + fun parseEmptySql() { + assertTrue(parseNamedParameters("").isEmpty()) + } + + @Test + fun parseSqlWithNoParameters() { + assertTrue(parseNamedParameters("SELECT * FROM users").isEmpty()) + } + + @Test + fun parseEscapedQuoteInStringLiteral() { + assertEquals(mapOf(":age" to 1), parseNamedParameters( + "SELECT * FROM users WHERE name = 'O''Brien' AND age = :age" + )) + } +} diff --git a/selekt-java/src/test/kotlin/com/bloomberg/selekt/SQLPreparedStatementTest.kt b/selekt-java/src/test/kotlin/com/bloomberg/selekt/SQLPreparedStatementTest.kt index f551feebb6..029b398616 100644 --- a/selekt-java/src/test/kotlin/com/bloomberg/selekt/SQLPreparedStatementTest.kt +++ b/selekt-java/src/test/kotlin/com/bloomberg/selekt/SQLPreparedStatementTest.kt @@ -38,16 +38,15 @@ private const val INTERVAL_MILLIS = 2_000L internal class SQLPreparedStatementTest { @Test - fun clearBindings() { - val sqlite = mock() - SQLPreparedStatement(POINTER, "SELECT * FROM Foo", sqlite, CommonThreadLocalRandom).clearBindings() - verify(sqlite, times(1)).clearBindings(eq(POINTER)) + fun clearBindings(): Unit = mock().run { + SQLPreparedStatement(POINTER, "SELECT * FROM Foo", this, CommonThreadLocalRandom).clearBindings() + verify(this, times(1)).clearBindings(eq(POINTER)) } @Test fun stepWithRetryDone() { - val sqlite = mock().apply { - whenever(stepWithoutThrowing(any())) doReturn SQL_DONE + val sqlite = mock { + whenever(it.stepWithoutThrowing(any())) doReturn SQL_DONE } val statement = SQLPreparedStatement(POINTER, "BEGIN IMMEDIATE TRANSACTION", sqlite, CommonThreadLocalRandom) assertEquals(SQL_DONE, statement.step(INTERVAL_MILLIS)) @@ -55,8 +54,8 @@ internal class SQLPreparedStatementTest { @Test fun stepWithRetryRow() { - val sqlite = mock().apply { - whenever(stepWithoutThrowing(any())) doReturn SQL_ROW + val sqlite = mock { + whenever(it.stepWithoutThrowing(any())) doReturn SQL_ROW } val statement = SQLPreparedStatement(POINTER, "SELECT * FROM Foo", sqlite, CommonThreadLocalRandom) assertEquals(SQL_ROW, statement.step(INTERVAL_MILLIS)) @@ -64,9 +63,9 @@ internal class SQLPreparedStatementTest { @Test fun stepWithRetryExpires() { - val sqlite = mock().apply { - whenever(databaseHandle(any())) doReturn DB - whenever(step(any())) doReturn SQL_BUSY + val sqlite = mock { + whenever(it.databaseHandle(any())) doReturn DB + whenever(it.step(any())) doReturn SQL_BUSY } val statement = SQLPreparedStatement(POINTER, "BEGIN BLAH", sqlite, CommonThreadLocalRandom) assertFailsWith { @@ -76,8 +75,8 @@ internal class SQLPreparedStatementTest { @Test fun stepWithRetryCanUltimatelySucceed() { - val sqlite = mock().apply { - whenever(stepWithoutThrowing(any())) doAnswer object : Answer { + val sqlite = mock { + whenever(it.stepWithoutThrowing(any())) doAnswer object : Answer { private var count = 0 override fun answer(invocation: InvocationOnMock) = when (count++) { @@ -92,9 +91,9 @@ internal class SQLPreparedStatementTest { @Test fun stepRetryDoesNotStackOverflow() { - val sqlite = mock().apply { - whenever(databaseHandle(any())) doReturn DB - whenever(stepWithoutThrowing(any())) doReturn SQL_BUSY + val sqlite = mock { + whenever(it.databaseHandle(any())) doReturn DB + whenever(it.stepWithoutThrowing(any())) doReturn SQL_BUSY } val statement = SQLPreparedStatement(POINTER, "BEGIN BLAH", sqlite, CommonThreadLocalRandom) assertFailsWith { @@ -112,29 +111,112 @@ internal class SQLPreparedStatementTest { @Test fun isBusyTrue() { - val sqlite = mock().apply { - whenever(databaseHandle(any())) doReturn DB - whenever(statementBusy(any())) doReturn 1 + val sqlite = mock { + whenever(it.databaseHandle(any())) doReturn DB + whenever(it.statementBusy(any())) doReturn 1 } assertTrue(SQLPreparedStatement(POINTER, "BEGIN BLAH", sqlite, CommonThreadLocalRandom).isBusy()) } @Test fun isBusyFalse() { - val sqlite = mock().apply { - whenever(databaseHandle(any())) doReturn DB - whenever(statementBusy(any())) doReturn 0 + val sqlite = mock { + whenever(it.databaseHandle(any())) doReturn DB + whenever(it.statementBusy(any())) doReturn 0 } assertFalse(SQLPreparedStatement(POINTER, "BEGIN BLAH", sqlite, CommonThreadLocalRandom).isBusy()) } @Test fun columnName() { - val sqlite = mock().apply { - whenever(databaseHandle(any())) doReturn DB - whenever(columnName(any(), any())) doReturn "foo" + val sqlite = mock { + whenever(it.databaseHandle(any())) doReturn DB + whenever(it.columnName(any(), any())) doReturn "foo" } assertEquals("foo", SQLPreparedStatement(POINTER, "BEGIN BLAH", sqlite, CommonThreadLocalRandom).columnName(0)) verify(sqlite, times(1)).columnName(eq(POINTER), eq(0)) } + + @Test + fun bindBlobByName() { + val blob = byteArrayOf(1, 2, 3) + val sqlite = mock { + whenever(it.bindParameterIndex(any(), eq(":data"))) doReturn 1 + } + SQLPreparedStatement(POINTER, "INSERT INTO t VALUES (:data)", sqlite, CommonThreadLocalRandom) + .bind(":data", blob) + verify(sqlite, times(1)).bindParameterIndex(eq(POINTER), eq(":data")) + verify(sqlite, times(1)).bindBlob(eq(POINTER), eq(1), eq(blob)) + } + + @Test + fun bindDoubleByName() { + val sqlite = mock { + whenever(it.bindParameterIndex(any(), eq(":value"))) doReturn 1 + } + SQLPreparedStatement(POINTER, "INSERT INTO t VALUES (:value)", sqlite, CommonThreadLocalRandom) + .bind(":value", 3.14) + verify(sqlite, times(1)).bindParameterIndex(eq(POINTER), eq(":value")) + verify(sqlite, times(1)).bindDouble(eq(POINTER), eq(1), eq(3.14)) + } + + @Test + fun bindIntByName() { + val sqlite = mock { + whenever(it.bindParameterIndex(any(), eq("@count"))) doReturn 2 + } + SQLPreparedStatement(POINTER, "INSERT INTO t VALUES (?, @count)", sqlite, CommonThreadLocalRandom) + .bind("@count", 42) + verify(sqlite, times(1)).bindParameterIndex(eq(POINTER), eq("@count")) + verify(sqlite, times(1)).bindInt(eq(POINTER), eq(2), eq(42)) + } + + @Test + fun bindLongByName() { + val sqlite = mock { + whenever(it.bindParameterIndex(any(), eq($$"$id"))) doReturn 1 + } + SQLPreparedStatement(POINTER, $$"SELECT * FROM t WHERE id = $id", sqlite, CommonThreadLocalRandom) + .bind($$"$id", 123_456_789L) + verify(sqlite, times(1)).bindParameterIndex(eq(POINTER), eq($$"$id")) + verify(sqlite, times(1)).bindInt64(eq(POINTER), eq(1), eq(123_456_789L)) + } + + @Test + fun bindStringByName() { + val sqlite = mock { + whenever(it.bindParameterIndex(any(), eq(":name"))) doReturn 1 + } + SQLPreparedStatement(POINTER, "INSERT INTO t VALUES (:name)", sqlite, CommonThreadLocalRandom) + .bind(":name", "test") + verify(sqlite, times(1)).bindParameterIndex(eq(POINTER), eq(":name")) + verify(sqlite, times(1)).bindText(eq(POINTER), eq(1), eq("test")) + } + + @Test + fun bindNullByName() { + val sqlite = mock { + whenever(it.bindParameterIndex(any(), eq(":nullable"))) doReturn 1 + } + SQLPreparedStatement(POINTER, "INSERT INTO t VALUES (:nullable)", sqlite, CommonThreadLocalRandom) + .bindNull(":nullable") + verify(sqlite, times(1)).bindParameterIndex(eq(POINTER), eq(":nullable")) + verify(sqlite, times(1)).bindNull(eq(POINTER), eq(1)) + } + + @Test + fun bindByNameThrowsForUnknownParameter() { + val sqlite = mock { + whenever(it.bindParameterIndex(any(), eq(":unknown"))) doReturn 0 + } + val statement = SQLPreparedStatement( + POINTER, + "INSERT INTO t VALUES (:known)", + sqlite, + CommonThreadLocalRandom + ) + assertFailsWith { + statement.bind(":unknown", "value") + } + } } diff --git a/selekt-java/src/test/kotlin/com/bloomberg/selekt/pools/SingleObjectPoolTest.kt b/selekt-java/src/test/kotlin/com/bloomberg/selekt/pools/SingleObjectPoolTest.kt index 6a2d544f34..dede251bba 100644 --- a/selekt-java/src/test/kotlin/com/bloomberg/selekt/pools/SingleObjectPoolTest.kt +++ b/selekt-java/src/test/kotlin/com/bloomberg/selekt/pools/SingleObjectPoolTest.kt @@ -401,7 +401,7 @@ internal class SingleObjectPoolTest { } @Test - fun interruptBorrowerThenReturn(): Unit = pool.run { + fun interruptBorrowerdoReturn(): Unit = pool.run { borrowObject().let { Thread.interrupted() assertDoesNotThrow { diff --git a/selekt-jdbc/build.gradle.kts b/selekt-jdbc/build.gradle.kts new file mode 100644 index 0000000000..3d72238e9c --- /dev/null +++ b/selekt-jdbc/build.gradle.kts @@ -0,0 +1,79 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@file:Suppress("UnstableApiUsage") + +description = "Selekt JDBC library." + +plugins { + kotlin("jvm") + id("com.android.lint") + alias(libs.plugins.kover) + alias(libs.plugins.dokka) + `maven-publish` + signing + alias(libs.plugins.detekt) +} + +repositories { + mavenCentral() + google() +} + +disableKotlinCompilerAssertions() + +java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(25)) + } + withJavadocJar() + withSourcesJar() +} + +dependencies { + implementation(projects.selektApi) + implementation(projects.selektJava) + implementation(projects.selektSqlite3Api) + implementation(projects.selektSqlite3Classes) { + capabilities { + requireCapability("com.bloomberg.selekt:selekt-sqlite3-classes-java25") + } + } + implementation(libs.slf4j.api) +} + +tasks.register("buildHostSQLite") { + dependsOn(":SQLite3:buildHost", "copyJniLibs") +} + +tasks.register("copyJniLibs") { + from(fileTree(project(":SQLite3").layout.buildDirectory.dir("intermediates/libs"))) + into(layout.buildDirectory.dir("intermediates/libs/jni")) + mustRunAfter(":SQLite3:buildHost") +} + +tasks.withType().configureEach { + dependsOn("buildHostSQLite") +} + +publishing { + publications.register("main") { + from(components.getByName("java")) + pom { + commonInitialisation(project) + } + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnection.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnection.kt new file mode 100644 index 0000000000..7002865785 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnection.kt @@ -0,0 +1,486 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.connection + +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.exception.SQLExceptionMapper +import com.bloomberg.selekt.jdbc.lob.JdbcClob +import com.bloomberg.selekt.jdbc.metadata.JdbcDatabaseMetaData +import com.bloomberg.selekt.jdbc.statement.JdbcPreparedStatement +import com.bloomberg.selekt.jdbc.statement.JdbcStatement +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.sql.Blob +import java.sql.CallableStatement +import java.sql.Clob +import java.sql.Connection +import java.sql.DatabaseMetaData +import java.sql.NClob +import java.sql.PreparedStatement +import java.sql.ResultSet +import java.sql.SQLException +import java.sql.SQLFeatureNotSupportedException +import java.sql.SQLWarning +import java.sql.SQLXML +import java.sql.Savepoint +import java.sql.Statement +import java.sql.Struct +import java.lang.invoke.MethodHandles +import java.util.Properties +import java.util.concurrent.Executor +import javax.annotation.concurrent.NotThreadSafe +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +@Suppress("MethodOverloading", "TooGenericExceptionCaught", "Detekt.StringLiteralDuplication") +@NotThreadSafe +internal class JdbcConnection( + private val database: SQLDatabase, + private val connectionURL: ConnectionURL, + private val properties: Properties +) : Connection { + companion object { + private val logger: Logger = LoggerFactory.getLogger(JdbcConnection::class.java) + + private val CLOSED = MethodHandles.lookup().findVarHandle( + JdbcConnection::class.java, + "closed", + Boolean::class.javaPrimitiveType + ) + } + + @Volatile + private var closed = false + private var autoCommit = true + private var readOnly = false + private var transactionIsolation = Connection.TRANSACTION_SERIALIZABLE + private var networkTimeout = 0 + private val holdability = ResultSet.CLOSE_CURSORS_AT_COMMIT + private val warnings = mutableListOf() + + private val _metaData by lazy { JdbcDatabaseMetaData(this, database, connectionURL) } + + init { + applyConnectionProperties() + } + + override fun createStatement(): Statement = createStatement( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + holdability + ) + + override fun createStatement(resultSetType: Int, resultSetConcurrency: Int): Statement = createStatement( + resultSetType, + resultSetConcurrency, + holdability + ) + + override fun createStatement( + resultSetType: Int, + resultSetConcurrency: Int, + resultSetHoldability: Int + ): Statement { + checkClosed() + if (resultSetType != ResultSet.TYPE_FORWARD_ONLY) { + throw SQLException("SQLite only supports TYPE_FORWARD_ONLY result sets") + } + if (resultSetConcurrency != ResultSet.CONCUR_READ_ONLY) { + throw SQLException("SQLite only supports CONCUR_READ_ONLY concurrency") + } + return JdbcStatement(this, database, resultSetType, resultSetConcurrency, resultSetHoldability) + } + + override fun prepareStatement(sql: String): PreparedStatement = prepareStatement( + sql, + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + holdability + ) + + override fun prepareStatement( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int + ): PreparedStatement = prepareStatement(sql, resultSetType, resultSetConcurrency, holdability) + + override fun prepareStatement( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int, + resultSetHoldability: Int + ): PreparedStatement { + checkClosed() + if (resultSetType != ResultSet.TYPE_FORWARD_ONLY) { + throw SQLException("SQLite only supports TYPE_FORWARD_ONLY result sets") + } + if (resultSetConcurrency != ResultSet.CONCUR_READ_ONLY) { + throw SQLException("SQLite only supports CONCUR_READ_ONLY concurrency") + } + return JdbcPreparedStatement(this, database, sql, resultSetType, resultSetConcurrency, resultSetHoldability) + } + + override fun prepareStatement( + sql: String, + autoGeneratedKeys: Int + ): PreparedStatement { + checkClosed() + return JdbcPreparedStatement( + this, + database, + sql, + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + holdability, + ) + } + + override fun prepareStatement( + sql: String, + columnIndexes: IntArray + ): PreparedStatement { + checkClosed() + return JdbcPreparedStatement( + this, + database, + sql, + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + holdability, + ) + } + + override fun prepareStatement( + sql: String, + columnNames: Array + ): PreparedStatement { + checkClosed() + return JdbcPreparedStatement( + this, + database, + sql, + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + holdability, + ) + } + + override fun prepareCall(sql: String): CallableStatement { + checkClosed() + throw SQLException("SQLite does not support stored procedures") + } + + override fun prepareCall( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int + ): CallableStatement { + checkClosed() + throw SQLException("SQLite does not support stored procedures") + } + + override fun prepareCall( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int, + resultSetHoldability: Int + ): CallableStatement { + checkClosed() + throw SQLException("SQLite does not support stored procedures") + } + + override fun nativeSQL(sql: String): String { + checkClosed() + return sql + } + + override fun setAutoCommit(autoCommit: Boolean) { + checkClosed() + if (this.autoCommit == autoCommit) { + return + } + database.runCatching { + if (autoCommit && inTransaction) { + setTransactionSuccessful() + endTransaction() + } + this@JdbcConnection.autoCommit = autoCommit + }.onFailure { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun getAutoCommit(): Boolean = autoCommit + + override fun commit() { + checkClosed() + if (autoCommit) { + throw SQLException("Cannot call commit() while in auto-commit mode") + } + database.runCatching { + if (inTransaction) { + setTransactionSuccessful() + endTransaction() + } + }.onFailure { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun rollback() { + checkClosed() + if (autoCommit) { + throw SQLException("Cannot call rollback() while in auto-commit mode") + } + database.runCatching { + if (inTransaction) { + endTransaction() + } + }.onFailure { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun rollback(savepoint: Savepoint?) { + checkClosed() + if (autoCommit) { + throw SQLException("Cannot call rollback() while in auto-commit mode") + } else if (savepoint == null) { + rollback() + return + } + runCatching { + database.exec("ROLLBACK TO SAVEPOINT ${savepoint.savepointName}") + }.onFailure { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun setSavepoint(): Savepoint = setSavepoint(null) + + override fun setSavepoint(name: String?): Savepoint { + checkClosed() + if (autoCommit) { + throw SQLException("Cannot create savepoint while in auto-commit mode") + } + val savepointName = name ?: "sp_${System.currentTimeMillis()}_${Thread.currentThread().threadId()}" + return runCatching { + database.exec("SAVEPOINT $savepointName") + object : Savepoint { + override fun getSavepointId(): Int = 0 + + override fun getSavepointName(): String = savepointName + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun releaseSavepoint(savepoint: Savepoint) { + checkClosed() + runCatching { + database.exec("RELEASE SAVEPOINT ${savepoint.savepointName}") + }.onFailure { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun close() { + if (CLOSED.compareAndSet(this, false, true)) { + database.runCatching { + if (inTransaction) { + endTransaction() + } + }.onFailure { e -> + logger.warn("Error ending transaction on connection close: ${e.message}") + } + } + } + + override fun isClosed(): Boolean = closed + + override fun getMetaData(): DatabaseMetaData = _metaData + + override fun setReadOnly(readOnly: Boolean) { + checkClosed() + this.readOnly = readOnly + } + + override fun isReadOnly(): Boolean = readOnly + + override fun setCatalog(catalog: String?) { + checkClosed() + } + + override fun getCatalog(): String? = null + + override fun setTransactionIsolation(level: Int) { + checkClosed() + when (level) { + Connection.TRANSACTION_SERIALIZABLE -> transactionIsolation = level + else -> addWarning( + SQLWarning("SQLite only supports TRANSACTION_SERIALIZABLE isolation level, ignoring level: $level") + ) + } + } + + override fun getTransactionIsolation(): Int = transactionIsolation + + override fun getWarnings(): SQLWarning? = warnings.firstOrNull() + + override fun clearWarnings() { + warnings.clear() + } + + override fun getTypeMap(): MutableMap> = mutableMapOf() + + override fun setTypeMap(map: MutableMap>?) { + checkClosed() + } + + override fun setHoldability(holdability: Int) { + checkClosed() + when (holdability) { + ResultSet.CLOSE_CURSORS_AT_COMMIT -> Unit + ResultSet.HOLD_CURSORS_OVER_COMMIT -> addWarning( + SQLWarning("SQLite does not support holdable cursors, ignoring HOLD_CURSORS_OVER_COMMIT") + ) + else -> throw SQLException("Unsupported holdability: $holdability") + } + } + + override fun getHoldability(): Int = holdability + + override fun setClientInfo(name: String, value: String?) { + checkClosed() + addWarning(SQLWarning("SQLite does not support client info properties, ignoring: $name=$value")) + } + + override fun setClientInfo(properties: Properties?) { + checkClosed() + addWarning(SQLWarning("SQLite does not support client info properties, ignoring properties")) + } + + override fun getClientInfo(name: String): String? { + checkClosed() + return null + } + + override fun getClientInfo(): Properties { + checkClosed() + return Properties() + } + + override fun createArrayOf(typeName: String, elements: Array): java.sql.Array { + throw SQLFeatureNotSupportedException("SQLite does not support arrays") + } + + override fun createStruct(typeName: String, attributes: Array): Struct { + throw SQLFeatureNotSupportedException("SQLite does not support structs") + } + + override fun createClob(): Clob { + checkClosed() + return JdbcClob() + } + + override fun createBlob(): Blob { + throw SQLFeatureNotSupportedException("Use byte arrays instead of BLOBs with SQLite") + } + + override fun createNClob(): NClob { + throw SQLFeatureNotSupportedException("SQLite does not support NCLOBs") + } + + override fun createSQLXML(): SQLXML { + throw SQLFeatureNotSupportedException("SQLite does not support SQLXML") + } + + override fun isValid(timeout: Int): Boolean { + if (isClosed) { + return false + } + return runCatching { + if (timeout < 0) { + throw SQLException("Timeout must be non-negative") + } + database.exec("SELECT 1") + true + }.getOrElse { e -> + logger.warn("Connection validation failed: {}", e.message) + false + } + } + + override fun setNetworkTimeout(executor: Executor?, milliseconds: Int) { + checkClosed() + if (milliseconds < 0) { + throw SQLException("Network timeout must be non-negative") + } + addWarning(SQLWarning("SQLite does not support network timeouts, ignoring timeout: $milliseconds")) + } + + override fun getNetworkTimeout(): Int = networkTimeout + + override fun getSchema(): String? = null + + override fun setSchema(schema: String?) { + checkClosed() + } + + override fun abort(executor: Executor?) { + close() + } + + override fun unwrap(iface: Class): T = if (iface.isAssignableFrom(this::class.java)) { + @Suppress("UNCHECKED_CAST") + this as T + } else if (iface.isAssignableFrom(SQLDatabase::class.java)) { + @Suppress("UNCHECKED_CAST") + database as T + } else { + throw SQLException("Cannot unwrap to ${iface.name}") + } + + override fun isWrapperFor(iface: Class<*>): Boolean { + return iface.isAssignableFrom(this::class.java) || iface.isAssignableFrom(SQLDatabase::class.java) + } + + private fun checkClosed() { + if (isClosed) { + throw SQLException("Connection is closed") + } + } + + private fun applyConnectionProperties() { + runCatching { + val foreignKeys = properties.getProperty("foreignKeys")?.toBoolean() ?: true + database.exec("PRAGMA foreign_keys = ${if (foreignKeys) { 1 } else { 0 } }") + }.onFailure { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + internal fun ensureTransaction() { + if (!autoCommit && !database.inTransaction) { + database.beginImmediateTransaction() + } + } + + internal fun addWarning(warning: SQLWarning) { + warnings.add(warning) + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDataSource.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDataSource.kt new file mode 100644 index 0000000000..a72266dda9 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDataSource.kt @@ -0,0 +1,332 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.driver + +import com.bloomberg.selekt.DatabaseConfiguration +import com.bloomberg.selekt.SQLCode +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.SQLite +import com.bloomberg.selekt.SQLiteJournalMode +import com.bloomberg.selekt.externalSQLiteSingleton +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.exception.SQLExceptionMapper +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.io.File +import java.io.PrintWriter +import java.sql.Connection +import java.sql.SQLException +import java.util.Properties +import java.util.concurrent.ConcurrentHashMap +import java.util.logging.Logger as JulLogger +import java.lang.invoke.MethodHandles +import java.lang.invoke.VarHandle +import javax.sql.DataSource +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +@Suppress("TooGenericExceptionCaught") +class SelektDataSource : DataSource { + companion object { + private const val PROPERTY_ENCRYPT = "encrypt" + private const val PROPERTY_KEY = "key" + private const val PROPERTY_POOL_SIZE = "poolSize" + private const val PROPERTY_BUSY_TIMEOUT = "busyTimeout" + private const val PROPERTY_JOURNAL_MODE = "journalMode" + private const val PROPERTY_FOREIGN_KEYS = "foreignKeys" + private const val DEFAULT_POOL_SIZE = 10 + private const val HEX_PREFIX_LENGTH = 2 + private const val HEX_CHUNK_SIZE = 2 + private const val HEX_RADIX = 16 + + private val CLOSED: VarHandle = MethodHandles.lookup() + .findVarHandle(SelektDataSource::class.java, "closed", Boolean::class.javaPrimitiveType) + } + + private val logger: Logger = LoggerFactory.getLogger(SelektDataSource::class.java) + + @Volatile + private var closed = false + + @Volatile + private var url: String = "" + + @Volatile + var databasePath: String = "" + set(value) { + field = value + url = "jdbc:selekt:$value" + } + + @Volatile + var maxPoolSize: Int = DEFAULT_POOL_SIZE + set(value) { + require(value > 0) { "Pool size must be positive" } + field = value + } + + @Volatile + var busyTimeout: Int = DatabaseConfiguration.COMMON_BUSY_TIMEOUT_MILLIS + set(value) { + require(value >= 0) { "Busy timeout must be non-negative" } + field = value + } + + @Volatile + var journalMode: String = "WAL" + set(value) { + val isValidMode = try { + SQLiteJournalMode.valueOf(value.uppercase()) + true + } catch (e: IllegalArgumentException) { + logger.debug("Invalid journal mode value '{}': {}", value, e.message) + false + } + require(isValidMode) { "Invalid journal mode: $value" } + field = value + } + + @Volatile + var foreignKeys: Boolean = true + + @Volatile + var encryptionEnabled: Boolean = false + + @Volatile + var encryptionKey: String? = null + + @Volatile + private var loginTimeoutSeconds = 0 + + @Volatile + private var logWriter: PrintWriter? = null + + private val databaseCache = ConcurrentHashMap() + + override fun getConnection(): Connection = getConnection(null, null) + + override fun getConnection(username: String?, password: String?): Connection { + if (closed) { + throw SQLException("DataSource is closed") + } + return runCatching { + val connectionURL = buildConnectionURL() + val mergedProperties = buildConnectionProperties() + + val database = getOrCreateDatabase(connectionURL, mergedProperties) + JdbcConnection(database, connectionURL, mergedProperties) + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Failed to create connection: ${e.message}", + -1, + -1, + e + ) + } + } + + fun setEncryption(enabled: Boolean, key: String? = null) { + encryptionEnabled = enabled + encryptionKey = key + } + + fun close() { + if (CLOSED.compareAndSet(this, false, true)) { + var firstException: Throwable? = null + databaseCache.values.forEach { database -> + runCatching { + database.close() + }.onFailure { e -> + if (firstException == null) { + firstException = e + } else { + firstException.addSuppressed(e) + } + } + } + databaseCache.clear() + logger.info("SelektDataSource closed") + firstException?.let { throw it } + } + } + + fun isClosed(): Boolean = closed + + override fun getLogWriter(): PrintWriter? = logWriter + + override fun setLogWriter(out: PrintWriter?) { + logWriter = out + } + + override fun setLoginTimeout(seconds: Int) { + if (seconds < 0) { + throw SQLException("Login timeout must be non-negative") + } + loginTimeoutSeconds = seconds + } + + override fun getLoginTimeout(): Int = loginTimeoutSeconds + + override fun getParentLogger(): JulLogger = JulLogger.getLogger(SelektDataSource::class.java.name) + + override fun unwrap(iface: Class): T { + if (iface.isInstance(this)) { + @Suppress("UNCHECKED_CAST") + return this as T + } + throw SQLException("Cannot unwrap to ${iface.name}") + } + + override fun isWrapperFor(iface: Class<*>): Boolean = iface.isInstance(this) + + private fun buildConnectionURL(): ConnectionURL { + val effectiveUrl = if (url.isNotEmpty()) { + url + } else if (databasePath.isNotEmpty()) { + buildUrlFromProperties() + } else { + throw SQLException("No database path or URL specified") + } + + return ConnectionURL.parse(effectiveUrl) + } + + private fun buildUrlFromProperties(): String { + val baseUrl = "jdbc:selekt:$databasePath" + return mutableListOf().apply { + if (encryptionEnabled) { + add("encrypt=true") + encryptionKey?.let { add("key=$it") } + } + add("poolSize=$maxPoolSize") + add("busyTimeout=$busyTimeout") + add("journalMode=$journalMode") + add("foreignKeys=$foreignKeys") + }.run { + if (isEmpty()) { + baseUrl + } else { + "$baseUrl?${joinToString("&")}" + } + } + } + + private fun buildConnectionProperties(): Properties = Properties().apply { + setProperty(PROPERTY_ENCRYPT, encryptionEnabled.toString()) + encryptionKey?.let { setProperty(PROPERTY_KEY, it) } + setProperty(PROPERTY_POOL_SIZE, maxPoolSize.toString()) + setProperty(PROPERTY_BUSY_TIMEOUT, busyTimeout.toString()) + setProperty(PROPERTY_JOURNAL_MODE, journalMode) + setProperty(PROPERTY_FOREIGN_KEYS, foreignKeys.toString()) + } + + private fun getOrCreateDatabase( + connectionURL: ConnectionURL, + properties: Properties + ): SQLDatabase { + val cacheKey = buildCacheKey(connectionURL, properties) + return databaseCache.computeIfAbsent(cacheKey) { + createDatabase(connectionURL, properties) + } + } + + private fun createDatabase( + connectionURL: ConnectionURL, + properties: Properties + ): SQLDatabase { + val configuration = buildDatabaseConfiguration(properties) + val encryptionKeyBytes = getEncryptionKey(properties) + val sqlite = object : SQLite(externalSQLiteSingleton()) { + override fun throwSQLException( + code: SQLCode, + extendedCode: SQLCode, + message: String, + context: String? + ): Nothing { + throw SQLExceptionMapper.mapException(message, code, extendedCode) + } + } + return SQLDatabase( + path = connectionURL.databasePath, + sqlite = sqlite, + configuration = configuration, + key = encryptionKeyBytes, + random = com.bloomberg.selekt.CommonThreadLocalRandom + ) + } + + private fun buildDatabaseConfiguration(properties: Properties): DatabaseConfiguration { + val poolSizeValue = properties.getProperty(PROPERTY_POOL_SIZE)?.toIntOrNull() ?: maxPoolSize + val busyTimeoutValue = properties.getProperty(PROPERTY_BUSY_TIMEOUT)?.toIntOrNull() ?: busyTimeout + val journalModeValue = properties.getProperty(PROPERTY_JOURNAL_MODE)?.let { + SQLiteJournalMode.valueOf(it.uppercase()) + } ?: SQLiteJournalMode.valueOf(journalMode.uppercase()) + val baseConfig = journalModeValue.databaseConfiguration + return baseConfig.copy( + maxConnectionPoolSize = poolSizeValue, + busyTimeoutMillis = busyTimeoutValue + ) + } + + private fun getEncryptionKey(properties: Properties): ByteArray? { + val encrypt = properties.getProperty(PROPERTY_ENCRYPT)?.toBoolean() == true + val keyProperty = properties.getProperty(PROPERTY_KEY) + if (!encrypt || keyProperty == null) { + return null + } + return keyProperty.run { + when { + startsWith("0x") || startsWith("0X") -> parseHexKey(this) + else -> parseStringOrFileKey(this) + } + } + } + + private fun parseHexKey(keyProperty: String): ByteArray = keyProperty.substring(HEX_PREFIX_LENGTH) + .chunked(HEX_CHUNK_SIZE) + .map { + it.toInt(HEX_RADIX).toByte() + }.toByteArray() + + private fun parseStringOrFileKey(keyProperty: String): ByteArray = runCatching { + val file = File(keyProperty) + if (file.exists() && file.isFile) { + file.readBytes() + } else { + keyProperty.toByteArray(Charsets.UTF_8) + } + }.getOrElse { e -> + logger.debug("Failed to read key from file '{}', treating as string key: {}", keyProperty, e.message) + keyProperty.toByteArray(Charsets.UTF_8) + } + + private fun buildCacheKey( + connectionURL: ConnectionURL, + properties: Properties + ): String { + val propString = listOf( + PROPERTY_ENCRYPT, + PROPERTY_KEY, + PROPERTY_POOL_SIZE, + PROPERTY_BUSY_TIMEOUT, + PROPERTY_JOURNAL_MODE, + PROPERTY_FOREIGN_KEYS + ).mapNotNull { key -> + properties.getProperty(key)?.let { "$key=$it" } + }.sorted().joinToString("&") + return "${connectionURL.databasePath}?$propString" + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDriver.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDriver.kt new file mode 100644 index 0000000000..a82273caef --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDriver.kt @@ -0,0 +1,252 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.driver + +import com.bloomberg.selekt.DatabaseConfiguration +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.SQLiteJournalMode +import com.bloomberg.selekt.externalSQLiteSingleton +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.exception.SQLExceptionMapper +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.io.File +import java.sql.Connection +import java.sql.Driver +import java.sql.DriverManager +import java.sql.DriverPropertyInfo +import java.sql.SQLException +import java.util.Properties +import java.util.logging.Logger as JulLogger +import java.util.concurrent.ConcurrentHashMap +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * Supports the URL format: jdbc:selekt:path/to/database.sqlite[?properties] + * + * Supported connection properties: + * - encrypt: Enable SQLCipher encryption (true/false) + * - key: Encryption key (hex string or file path) + * - poolSize: Maximum connection pool size (integer) + * - busyTimeout: SQLite busy timeout in milliseconds (integer) + * - journalMode: SQLite journal mode (DELETE, WAL, MEMORY, etc.) + * - foreignKeys: Enable foreign key constraints (true/false) + */ +@Suppress("TooGenericExceptionCaught") +class SelektDriver : Driver { + companion object { + private val logger: Logger = LoggerFactory.getLogger(SelektDriver::class.java) + + const val DRIVER_NAME = "Selekt JDBC Driver" + const val DRIVER_VERSION = "4.3.0" + const val MAJOR_VERSION = 4 + const val MINOR_VERSION = 3 + + private const val PROPERTY_ENCRYPT = "encrypt" + private const val PROPERTY_KEY = "key" + private const val PROPERTY_POOL_SIZE = "poolSize" + private const val PROPERTY_BUSY_TIMEOUT = "busyTimeout" + private const val PROPERTY_JOURNAL_MODE = "journalMode" + private const val PROPERTY_FOREIGN_KEYS = "foreignKeys" + + private const val DEFAULT_POOL_SIZE = 10 + + private const val HEX_PREFIX_LENGTH = 2 + private const val HEX_CHUNK_SIZE = 2 + private const val HEX_RADIX = 16 + + private val BOOLEAN_CHOICES = arrayOf("true", "false") + + private val databaseCache = ConcurrentHashMap() + + init { + runCatching { + DriverManager.registerDriver(SelektDriver()) + logger.info("{} {} registered successfully", DRIVER_NAME, DRIVER_VERSION) + }.onFailure { e -> + logger.error("Failed to register {}: {}", DRIVER_NAME, e.message) + throw SQLException("Failed to register Selekt JDBC driver", e) + } + } + } + + override fun connect(url: String, info: Properties): Connection? = if (!acceptsURL(url)) { + null + } else { + runCatching { + val connectionURL = ConnectionURL.parse(url) + val mergedProperties = mergeProperties(connectionURL.properties, info) + val database = getOrCreateDatabase(connectionURL, mergedProperties) + JdbcConnection(database, connectionURL, mergedProperties) + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Failed to create connection to $url: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun acceptsURL(url: String?): Boolean = url != null && ConnectionURL.isValidUrl(url) + + override fun getPropertyInfo(url: String, info: Properties): Array = if (!acceptsURL(url)) { + throw SQLException("Invalid URL format: $url") + } else { + arrayOf( + DriverPropertyInfo(PROPERTY_ENCRYPT, info.getProperty(PROPERTY_ENCRYPT, "false")).apply { + description = "Enable SQLCipher encryption" + required = false + choices = BOOLEAN_CHOICES + }, + DriverPropertyInfo(PROPERTY_KEY, info.getProperty(PROPERTY_KEY)).apply { + description = "Encryption key (hex string or file path)" + required = false + }, + DriverPropertyInfo(PROPERTY_POOL_SIZE, info.getProperty(PROPERTY_POOL_SIZE, "10")).apply { + description = "Maximum connection pool size" + required = false + }, + DriverPropertyInfo(PROPERTY_BUSY_TIMEOUT, info.getProperty(PROPERTY_BUSY_TIMEOUT, "30000")).apply { + description = "SQLite busy timeout in milliseconds" + required = false + }, + DriverPropertyInfo(PROPERTY_JOURNAL_MODE, info.getProperty(PROPERTY_JOURNAL_MODE, "WAL")).apply { + description = "SQLite journal mode" + required = false + choices = arrayOf("DELETE", "WAL", "MEMORY", "PERSIST", "TRUNCATE", "OFF") + }, + DriverPropertyInfo(PROPERTY_FOREIGN_KEYS, info.getProperty(PROPERTY_FOREIGN_KEYS, "true")).apply { + description = "Enable foreign key constraints" + required = false + choices = BOOLEAN_CHOICES + } + ) + } + + override fun getMajorVersion(): Int = MAJOR_VERSION + + override fun getMinorVersion(): Int = MINOR_VERSION + + override fun jdbcCompliant(): Boolean = false + + override fun getParentLogger(): JulLogger = JulLogger.getLogger(SelektDriver::class.java.name) + + private fun getOrCreateDatabase( + connectionURL: ConnectionURL, + properties: Properties + ): SQLDatabase = databaseCache.computeIfAbsent(buildCacheKey(connectionURL, properties)) { + createDatabase(connectionURL, properties) + } + + private fun createDatabase( + connectionURL: ConnectionURL, + properties: Properties + ): SQLDatabase { + val configuration = buildDatabaseConfiguration(properties) + val encryptionKey = getEncryptionKey(properties) + val sqlite = object : com.bloomberg.selekt.SQLite( + externalSQLiteSingleton() + ) { + override fun throwSQLException( + code: com.bloomberg.selekt.SQLCode, + extendedCode: com.bloomberg.selekt.SQLCode, + message: String, + context: String? + ): Nothing { + throw SQLExceptionMapper.mapException(message, code, extendedCode) + } + } + return SQLDatabase( + path = connectionURL.databasePath, + sqlite = sqlite, + configuration = configuration, + key = encryptionKey, + random = com.bloomberg.selekt.CommonThreadLocalRandom + ) + } + + private fun buildDatabaseConfiguration(properties: Properties): DatabaseConfiguration { + val poolSize = properties.getProperty(PROPERTY_POOL_SIZE)?.toIntOrNull() ?: DEFAULT_POOL_SIZE + val busyTimeout = properties.getProperty(PROPERTY_BUSY_TIMEOUT)?.toIntOrNull() + ?: DatabaseConfiguration.COMMON_BUSY_TIMEOUT_MILLIS + val journalMode = properties.getProperty(PROPERTY_JOURNAL_MODE)?.let { + SQLiteJournalMode.valueOf(it.uppercase()) + } ?: SQLiteJournalMode.WAL + val baseConfig = journalMode.databaseConfiguration + return baseConfig.copy( + maxConnectionPoolSize = poolSize, + busyTimeoutMillis = busyTimeout + ) + } + + private fun getEncryptionKey(properties: Properties): ByteArray? { + val encrypt = properties.getProperty(PROPERTY_ENCRYPT)?.toBoolean() == true + val keyProperty = properties.getProperty(PROPERTY_KEY) + if (!encrypt || keyProperty == null) { + return null + } + return when { + keyProperty.startsWith("0x") || keyProperty.startsWith("0X") -> + parseHexKey(keyProperty) + else -> parseStringOrFileKey(keyProperty) + } + } + + private fun parseHexKey(keyProperty: String): ByteArray = keyProperty.substring(HEX_PREFIX_LENGTH) + .chunked(HEX_CHUNK_SIZE) + .map { + it.toInt(HEX_RADIX).toByte() + }.toByteArray() + + private fun parseStringOrFileKey(keyProperty: String): ByteArray = runCatching { + val file = File(keyProperty) + if (file.exists() && file.isFile) { + file.readBytes() + } else { + keyProperty.toByteArray(Charsets.UTF_8) + } + }.getOrElse { e -> + logger.debug("Failed to read key from file '{}', treating as string key: {}", keyProperty, e.message) + keyProperty.toByteArray(Charsets.UTF_8) + } + + private fun mergeProperties( + urlProperties: Properties, + additionalProperties: Properties + ): Properties = Properties().apply { + putAll(urlProperties) + putAll(additionalProperties) + } + + private fun buildCacheKey( + connectionURL: ConnectionURL, + properties: Properties + ): String { + val propString = listOf( + PROPERTY_ENCRYPT, + PROPERTY_KEY, + PROPERTY_POOL_SIZE, + PROPERTY_BUSY_TIMEOUT, + PROPERTY_JOURNAL_MODE, + PROPERTY_FOREIGN_KEYS + ).mapNotNull { key -> + properties.getProperty(key)?.let { "$key=$it" } + }.sorted().joinToString("&") + return "${connectionURL.databasePath}?$propString" + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/exception/SQLExceptionMapper.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/exception/SQLExceptionMapper.kt new file mode 100644 index 0000000000..17c41cbb24 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/exception/SQLExceptionMapper.kt @@ -0,0 +1,294 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.exception + +import com.bloomberg.selekt.SQL_ABORT +import com.bloomberg.selekt.SQL_ABORT_ROLLBACK +import com.bloomberg.selekt.SQL_AUTH +import com.bloomberg.selekt.SQL_BUSY +import com.bloomberg.selekt.SQL_CANT_OPEN +import com.bloomberg.selekt.SQL_CONSTRAINT +import com.bloomberg.selekt.SQL_CORRUPT +import com.bloomberg.selekt.SQL_DONE +import com.bloomberg.selekt.SQL_ERROR +import com.bloomberg.selekt.SQL_FULL +import com.bloomberg.selekt.SQL_IO_ERROR +import com.bloomberg.selekt.SQL_IO_ERROR_ACCESS +import com.bloomberg.selekt.SQL_IO_ERROR_BLOCKED +import com.bloomberg.selekt.SQL_IO_ERROR_CHECK_RESERVED_LOCK +import com.bloomberg.selekt.SQL_IO_ERROR_CLOSE +import com.bloomberg.selekt.SQL_IO_ERROR_CONVPATH +import com.bloomberg.selekt.SQL_IO_ERROR_DELETE +import com.bloomberg.selekt.SQL_IO_ERROR_DELETE_NO_ENT +import com.bloomberg.selekt.SQL_IO_ERROR_DIR_CLOSE +import com.bloomberg.selekt.SQL_IO_ERROR_DIR_FSYNC +import com.bloomberg.selekt.SQL_IO_ERROR_DIR_FSTAT +import com.bloomberg.selekt.SQL_IO_ERROR_FSYNC +import com.bloomberg.selekt.SQL_IO_ERROR_GET_TEMP_PATH +import com.bloomberg.selekt.SQL_IO_ERROR_MMAP +import com.bloomberg.selekt.SQL_IO_ERROR_RDLOCK +import com.bloomberg.selekt.SQL_IO_ERROR_SEEK +import com.bloomberg.selekt.SQL_IO_ERROR_SHMLOCK +import com.bloomberg.selekt.SQL_IO_ERROR_SHMMAP +import com.bloomberg.selekt.SQL_IO_ERROR_SHMOPEN +import com.bloomberg.selekt.SQL_IO_ERROR_SHMSIZE +import com.bloomberg.selekt.SQL_IO_ERROR_SHORT_READ +import com.bloomberg.selekt.SQL_IO_ERROR_TRUNCATE +import com.bloomberg.selekt.SQL_IO_ERROR_WRITE +import com.bloomberg.selekt.SQL_IO_ERROR_LOCK +import com.bloomberg.selekt.SQL_IO_ERROR_NOMEM +import com.bloomberg.selekt.SQL_IO_ERROR_READ +import com.bloomberg.selekt.SQL_IO_ERROR_UNLOCK +import com.bloomberg.selekt.SQL_LOCKED +import com.bloomberg.selekt.SQL_LOCKED_SHARED_CACHE +import com.bloomberg.selekt.SQL_LOCKED_VTAB +import com.bloomberg.selekt.SQL_MISMATCH +import com.bloomberg.selekt.SQL_MISUSE +import com.bloomberg.selekt.SQL_NOMEM +import com.bloomberg.selekt.SQL_NOT_A_DATABASE +import com.bloomberg.selekt.SQL_NOT_FOUND +import com.bloomberg.selekt.SQL_NOTICE_RECOVER_ROLLBACK +import com.bloomberg.selekt.SQL_NOTICE_RECOVER_WAL +import com.bloomberg.selekt.SQL_OK +import com.bloomberg.selekt.SQL_OK_LOAD_PERMANENTLY +import com.bloomberg.selekt.SQL_RANGE +import com.bloomberg.selekt.SQL_READONLY +import com.bloomberg.selekt.SQL_READONLY_CANT_INIT +import com.bloomberg.selekt.SQL_READONLY_CANT_LOCK +import com.bloomberg.selekt.SQL_READONLY_DB_MOVED +import com.bloomberg.selekt.SQL_READONLY_DIRECTORY +import com.bloomberg.selekt.SQL_READONLY_RECOVERY +import com.bloomberg.selekt.SQL_READONLY_ROLLBACK +import com.bloomberg.selekt.SQL_ROW +import com.bloomberg.selekt.SQL_TOO_BIG +import com.bloomberg.selekt.SQL_WARNING_AUTOINDEX +import com.bloomberg.selekt.SQLCode +import java.sql.SQLException +import java.sql.SQLDataException +import java.sql.SQLIntegrityConstraintViolationException +import java.sql.SQLNonTransientConnectionException +import java.sql.SQLNonTransientException +import java.sql.SQLRecoverableException +import java.sql.SQLTimeoutException +import java.sql.SQLTransactionRollbackException +import java.sql.SQLTransientConnectionException +import java.sql.SQLTransientException + +internal object SQLExceptionMapper { + private const val SQLSTATE_HY000 = "HY000" + private const val SQLSTATE_53000 = "53000" + private const val SQLSTATE_40001 = "40001" + private const val SQLSTATE_08007 = "08007" + + @JvmStatic + fun mapException(selektException: SQLException): SQLException = mapException( + selektException.message ?: "Unknown error", + extractSQLCode(selektException), + extractExtendedSQLCode(selektException), + selektException + ) + + @JvmStatic + fun mapException( + message: String, + sqlCode: SQLCode, + extendedSQLCode: SQLCode = -1, + cause: Throwable? = null + ): SQLException { + val (exceptionClass, sqlState) = mapSQLCodeToExceptionClass(sqlCode, extendedSQLCode) + val enhancedMessage = buildMessage(message, sqlCode, extendedSQLCode) + return when (exceptionClass) { + ExceptionType.DATA_EXCEPTION -> SQLDataException(enhancedMessage, sqlState, sqlCode, cause) + ExceptionType.INTEGRITY_CONSTRAINT_VIOLATION -> SQLIntegrityConstraintViolationException( + enhancedMessage, + sqlState, + sqlCode, + cause + ) + ExceptionType.NON_TRANSIENT_CONNECTION -> SQLNonTransientConnectionException( + enhancedMessage, + sqlState, + sqlCode, + cause + ) + ExceptionType.NON_TRANSIENT -> SQLNonTransientException(enhancedMessage, sqlState, sqlCode, cause) + ExceptionType.RECOVERABLE -> SQLRecoverableException(enhancedMessage, sqlState, sqlCode, cause) + ExceptionType.TIMEOUT -> SQLTimeoutException(enhancedMessage, sqlState, sqlCode, cause) + ExceptionType.TRANSACTION_ROLLBACK -> SQLTransactionRollbackException(enhancedMessage, sqlState, sqlCode, cause) + ExceptionType.TRANSIENT_CONNECTION -> SQLTransientConnectionException(enhancedMessage, sqlState, sqlCode, cause) + ExceptionType.TRANSIENT -> SQLTransientException(enhancedMessage, sqlState, sqlCode, cause) + ExceptionType.GENERIC -> SQLException(enhancedMessage, sqlState, sqlCode, cause) + } + } + + private enum class ExceptionType { + DATA_EXCEPTION, + INTEGRITY_CONSTRAINT_VIOLATION, + NON_TRANSIENT_CONNECTION, + NON_TRANSIENT, + RECOVERABLE, + TIMEOUT, + TRANSACTION_ROLLBACK, + TRANSIENT_CONNECTION, + TRANSIENT, + GENERIC + } + + private fun mapSQLCodeToExceptionClass( + sqlCode: SQLCode, + extendedSQLCode: SQLCode + ): Pair = when (sqlCode) { + SQL_CONSTRAINT -> ExceptionType.INTEGRITY_CONSTRAINT_VIOLATION to "23000" + SQL_MISMATCH -> ExceptionType.DATA_EXCEPTION to "22000" + SQL_TOO_BIG -> ExceptionType.DATA_EXCEPTION to "22001" + SQL_RANGE -> ExceptionType.DATA_EXCEPTION to "22003" + SQL_CANT_OPEN -> ExceptionType.NON_TRANSIENT_CONNECTION to "08001" + SQL_NOT_A_DATABASE -> ExceptionType.NON_TRANSIENT_CONNECTION to SQLSTATE_08007 + SQL_CORRUPT -> ExceptionType.NON_TRANSIENT_CONNECTION to SQLSTATE_08007 + SQL_AUTH -> ExceptionType.NON_TRANSIENT_CONNECTION to "28000" + SQL_BUSY -> when { + isTimeoutRelated(extendedSQLCode) -> ExceptionType.TIMEOUT to "HYT00" + else -> ExceptionType.TRANSIENT to SQLSTATE_40001 + } + SQL_LOCKED, SQL_LOCKED_SHARED_CACHE, SQL_LOCKED_VTAB -> + ExceptionType.TRANSIENT to SQLSTATE_40001 + + SQL_ABORT, SQL_ABORT_ROLLBACK -> ExceptionType.TRANSACTION_ROLLBACK to "40000" + SQL_IO_ERROR -> when (extendedSQLCode) { + SQL_IO_ERROR_NOMEM -> ExceptionType.RECOVERABLE to SQLSTATE_53000 + SQL_IO_ERROR_ACCESS, SQL_IO_ERROR_LOCK, SQL_IO_ERROR_UNLOCK -> + ExceptionType.TRANSIENT to SQLSTATE_HY000 + else -> ExceptionType.NON_TRANSIENT to SQLSTATE_HY000 + } + SQL_NOMEM -> ExceptionType.RECOVERABLE to SQLSTATE_53000 + SQL_FULL -> ExceptionType.NON_TRANSIENT to "53100" + SQL_READONLY -> ExceptionType.NON_TRANSIENT to "25006" + SQL_MISUSE -> ExceptionType.NON_TRANSIENT to "HY010" + SQL_NOT_FOUND -> ExceptionType.NON_TRANSIENT to "42000" + SQL_ERROR -> ExceptionType.NON_TRANSIENT to SQLSTATE_HY000 + SQL_OK, SQL_ROW, SQL_DONE -> ExceptionType.GENERIC to "00000" + else -> ExceptionType.GENERIC to SQLSTATE_HY000 + } + + private fun isTimeoutRelated(extendedSQLCode: SQLCode): Boolean { + return when (extendedSQLCode) { + SQL_IO_ERROR_BLOCKED -> true + else -> false + } + } + + private fun buildMessage(message: String, sqlCode: SQLCode, extendedSQLCode: SQLCode): String { + val codeDescription = getSQLCodeDescription(sqlCode) + val extendedDescription = if (extendedSQLCode != -1) { + getExtendedSQLCodeDescription(extendedSQLCode) + } else { + null + } + return buildString { + append(message) + if (codeDescription.isNotEmpty()) { + append(" (").append(codeDescription) + if (extendedDescription != null) { + append("; ").append(extendedDescription) + } + append(")") + } + } + } + + private fun getSQLCodeDescription(sqlCode: SQLCode): String = when (sqlCode) { + SQL_OK -> "SQLITE_OK" + SQL_ERROR -> "SQLITE_ERROR" + SQL_ABORT -> "SQLITE_ABORT" + SQL_BUSY -> "SQLITE_BUSY" + SQL_LOCKED -> "SQLITE_LOCKED" + SQL_NOMEM -> "SQLITE_NOMEM" + SQL_READONLY -> "SQLITE_READONLY" + SQL_IO_ERROR -> "SQLITE_IOERR" + SQL_CORRUPT -> "SQLITE_CORRUPT" + SQL_NOT_FOUND -> "SQLITE_NOTFOUND" + SQL_FULL -> "SQLITE_FULL" + SQL_CANT_OPEN -> "SQLITE_CANTOPEN" + SQL_TOO_BIG -> "SQLITE_TOOBIG" + SQL_CONSTRAINT -> "SQLITE_CONSTRAINT" + SQL_MISMATCH -> "SQLITE_MISMATCH" + SQL_MISUSE -> "SQLITE_MISUSE" + SQL_AUTH -> "SQLITE_AUTH" + SQL_RANGE -> "SQLITE_RANGE" + SQL_NOT_A_DATABASE -> "SQLITE_NOTADB" + SQL_ROW -> "SQLITE_ROW" + SQL_DONE -> "SQLITE_DONE" + else -> "SQLITE_UNKNOWN($sqlCode)" + } + + private fun getExtendedSQLCodeDescription(extendedSQLCode: SQLCode): String = when (extendedSQLCode) { + SQL_ABORT_ROLLBACK -> "SQLITE_ABORT_ROLLBACK" + SQL_IO_ERROR_ACCESS -> "SQLITE_IOERR_ACCESS" + SQL_IO_ERROR_BLOCKED -> "SQLITE_IOERR_BLOCKED" + SQL_IO_ERROR_CHECK_RESERVED_LOCK -> "SQLITE_IOERR_CHECKRESERVEDLOCK" + SQL_IO_ERROR_CLOSE -> "SQLITE_IOERR_CLOSE" + SQL_IO_ERROR_CONVPATH -> "SQLITE_IOERR_CONVPATH" + SQL_IO_ERROR_DELETE -> "SQLITE_IOERR_DELETE" + SQL_IO_ERROR_DELETE_NO_ENT -> "SQLITE_IOERR_DELETE_NOENT" + SQL_IO_ERROR_DIR_CLOSE -> "SQLITE_IOERR_DIR_CLOSE" + SQL_IO_ERROR_DIR_FSYNC -> "SQLITE_IOERR_DIR_FSYNC" + SQL_IO_ERROR_DIR_FSTAT -> "SQLITE_IOERR_DIR_FSTAT" + SQL_IO_ERROR_FSYNC -> "SQLITE_IOERR_FSYNC" + SQL_IO_ERROR_GET_TEMP_PATH -> "SQLITE_IOERR_GETTEMPPATH" + SQL_IO_ERROR_LOCK -> "SQLITE_IOERR_LOCK" + SQL_IO_ERROR_MMAP -> "SQLITE_IOERR_MMAP" + SQL_IO_ERROR_NOMEM -> "SQLITE_IOERR_NOMEM" + SQL_IO_ERROR_RDLOCK -> "SQLITE_IOERR_RDLOCK" + SQL_IO_ERROR_READ -> "SQLITE_IOERR_READ" + SQL_IO_ERROR_SEEK -> "SQLITE_IOERR_SEEK" + SQL_IO_ERROR_SHMLOCK -> "SQLITE_IOERR_SHMLOCK" + SQL_IO_ERROR_SHMMAP -> "SQLITE_IOERR_SHMMAP" + SQL_IO_ERROR_SHMOPEN -> "SQLITE_IOERR_SHMOPEN" + SQL_IO_ERROR_SHMSIZE -> "SQLITE_IOERR_SHMSIZE" + SQL_IO_ERROR_SHORT_READ -> "SQLITE_IOERR_SHORT_READ" + SQL_IO_ERROR_TRUNCATE -> "SQLITE_IOERR_TRUNCATE" + SQL_IO_ERROR_UNLOCK -> "SQLITE_IOERR_UNLOCK" + SQL_IO_ERROR_WRITE -> "SQLITE_IOERR_WRITE" + SQL_LOCKED_SHARED_CACHE -> "SQLITE_LOCKED_SHAREDCACHE" + SQL_LOCKED_VTAB -> "SQLITE_LOCKED_VTAB" + SQL_NOTICE_RECOVER_ROLLBACK -> "SQLITE_NOTICE_RECOVER_ROLLBACK" + SQL_NOTICE_RECOVER_WAL -> "SQLITE_NOTICE_RECOVER_WAL" + SQL_OK_LOAD_PERMANENTLY -> "SQLITE_OK_LOAD_PERMANENTLY" + SQL_READONLY_CANT_INIT -> "SQLITE_READONLY_CANTINIT" + SQL_READONLY_CANT_LOCK -> "SQLITE_READONLY_CANTLOCK" + SQL_READONLY_DB_MOVED -> "SQLITE_READONLY_DBMOVED" + SQL_READONLY_DIRECTORY -> "SQLITE_READONLY_DIRECTORY" + SQL_READONLY_RECOVERY -> "SQLITE_READONLY_RECOVERY" + SQL_READONLY_ROLLBACK -> "SQLITE_READONLY_ROLLBACK" + SQL_WARNING_AUTOINDEX -> "SQLITE_WARNING_AUTOINDEX" + else -> "SQLITE_UNKNOWN_EXTENDED($extendedSQLCode)" + } + + private fun extractSQLCode(exception: SQLException): SQLCode { + val message = exception.message ?: "" + val codePattern = Regex("Code: (\\d+)") + val match = codePattern.find(message) + return match?.groups?.get(1)?.value?.toIntOrNull() ?: exception.errorCode + } + + private fun extractExtendedSQLCode(exception: SQLException): SQLCode { + val message = exception.message ?: "" + val extendedPattern = Regex("Extended: (\\d+)") + val match = extendedPattern.find(message) + return match?.groups?.get(1)?.value?.toIntOrNull() ?: -1 + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/lob/JdbcClob.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/lob/JdbcClob.kt new file mode 100644 index 0000000000..020ba2c017 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/lob/JdbcClob.kt @@ -0,0 +1,236 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.lob + +import java.io.ByteArrayInputStream +import java.io.InputStream +import java.io.OutputStream +import java.io.Reader +import java.io.StringReader +import java.io.Writer +import java.sql.Clob +import java.sql.SQLException +import java.util.concurrent.atomic.AtomicBoolean +import javax.annotation.concurrent.NotThreadSafe + +@Suppress("Detekt.StringLiteralDuplication") +@NotThreadSafe +internal class JdbcClob : Clob { + private val content = StringBuilder() + private val freed = AtomicBoolean(false) + + constructor() + + constructor(initialContent: String) { + content.append(initialContent) + } + + override fun length(): Long { + checkNotFreed() + return content.length.toLong() + } + + override fun getSubString(pos: Long, length: Int): String { + checkNotFreed() + if (pos < 1) { + throw SQLException("Position must be >= 1 (got $pos)") + } else if (length < 0) { + throw SQLException("Length must be non-negative (got $length)") + } + val startIndex = (pos - 1).toInt() + if (startIndex < 0 || startIndex > content.length) { + throw SQLException("Position $pos is out of bounds (length=${content.length})") + } + val endIndex = minOf(startIndex + length, content.length) + return content.substring(startIndex, endIndex) + } + + override fun getCharacterStream(): Reader { + checkNotFreed() + return StringReader(content.toString()) + } + + override fun getCharacterStream(pos: Long, length: Long): Reader { + checkNotFreed() + val substring = getSubString(pos, length.toInt()) + return StringReader(substring) + } + + override fun getAsciiStream(): InputStream { + checkNotFreed() + return ByteArrayInputStream(content.toString().toByteArray(Charsets.US_ASCII)) + } + + override fun position(searchstr: String, start: Long): Long { + checkNotFreed() + if (start < 1) { + throw SQLException("Start position must be >= 1 (got $start)") + } + val startIndex = (start - 1).toInt() + if (startIndex >= content.length) { + return -1L + } + val index = content.indexOf(searchstr, startIndex) + return if (index >= 0) { + (index + 1).toLong() + } else { + -1L + } + } + + override fun position(searchstr: Clob, start: Long): Long { + checkNotFreed() + val searchString = searchstr.getSubString(1, searchstr.length().toInt()) + return position(searchString, start) + } + + override fun setString(pos: Long, str: String): Int { + checkNotFreed() + return setString(pos, str, 0, str.length) + } + + override fun setString(pos: Long, str: String, offset: Int, len: Int): Int { + checkNotFreed() + if (pos < 1) { + throw SQLException("Position must be >= 1 (got $pos)") + } else if (offset < 0 || offset > str.length) { + throw SQLException("Offset $offset is out of bounds for string of length ${str.length}") + } else if (len < 0 || offset + len > str.length) { + throw SQLException("Length $len with offset $offset exceeds string length ${str.length}") + } + val startIndex = (pos - 1).toInt() + val substring = str.substring(offset, offset + len) + while (content.length < startIndex) { + content.append(' ') + } + if (startIndex < content.length) { + val endIndex = minOf(startIndex + len, content.length) + content.replace(startIndex, endIndex, substring) + if (len > endIndex - startIndex) { + content.append(substring.substring(endIndex - startIndex)) + } + } else { + content.append(substring) + } + return len + } + + override fun setCharacterStream(pos: Long): Writer { + checkNotFreed() + if (pos < 1) { + throw SQLException("Position must be >= 1 (got $pos)") + } + val startIndex = (pos - 1).toInt() + return object : Writer() { + override fun write(cbuf: CharArray, off: Int, len: Int) { + checkNotFreed() + while (content.length < startIndex) { + content.append(' ') + } + val str = String(cbuf, off, len) + if (content.length == startIndex) { + content.append(str) + } else { + val endIndex = minOf(startIndex + len, content.length) + content.replace(startIndex, endIndex, str) + if (len > endIndex - startIndex) { + content.append(str.substring(endIndex - startIndex)) + } + } + } + + override fun flush() = Unit + + override fun close() = Unit + } + } + + override fun setAsciiStream(pos: Long): OutputStream { + checkNotFreed() + if (pos < 1) { + throw SQLException("Position must be >= 1 (got $pos)") + } + val startIndex = (pos - 1).toInt() + return object : OutputStream() { + private var currentPos = startIndex + + override fun write(b: Int) { + checkNotFreed() + while (content.length < currentPos) { + content.append(' ') + } + val char = b.toChar() + if (currentPos < content.length) { + content.setCharAt(currentPos, char) + } else { + content.append(char) + } + ++currentPos + } + + override fun write(b: ByteArray, off: Int, len: Int) { + checkNotFreed() + val str = String(b, off, len, Charsets.US_ASCII) + while (content.length < currentPos) { + content.append(' ') + } + if (currentPos < content.length) { + val endIndex = minOf(currentPos + len, content.length) + content.replace(currentPos, endIndex, str) + if (len > endIndex - currentPos) { + content.append(str.substring(endIndex - currentPos)) + } + } else { + content.append(str) + } + currentPos += len + } + + override fun flush() = Unit + + override fun close() = Unit + } + } + + override fun truncate(len: Long) { + checkNotFreed() + if (len < 0) { + throw SQLException("Length must be non-negative (got $len)") + } else if (len >= content.length) { + return + } + content.setLength(len.toInt()) + } + + override fun free() { + if (freed.compareAndSet(false, true)) { + content.clear() + content.trimToSize() + } + } + + private fun checkNotFreed() { + if (freed.get()) { + throw SQLException("Clob has been freed") + } + } + + internal fun asString(): String { + checkNotFreed() + return content.toString() + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/metadata/JdbcDatabaseMetaData.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/metadata/JdbcDatabaseMetaData.kt new file mode 100644 index 0000000000..e601b498c5 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/metadata/JdbcDatabaseMetaData.kt @@ -0,0 +1,806 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.metadata + +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.result.JdbcResultSet +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.sql.Connection +import java.sql.DatabaseMetaData +import java.sql.ResultSet +import java.sql.RowIdLifetime +import java.sql.SQLException +import java.sql.SQLFeatureNotSupportedException +import java.sql.Types +import javax.annotation.concurrent.NotThreadSafe + +@Suppress("Detekt.LargeClass") +@NotThreadSafe +internal class JdbcDatabaseMetaData( + private val connection: JdbcConnection, + private val database: SQLDatabase, + private val connectionURL: ConnectionURL +) : DatabaseMetaData { + override fun getConnection(): Connection = connection + + override fun getDatabaseProductName(): String = "SQLite" + + override fun getDatabaseProductVersion(): String = "3.51.2" + + override fun getDatabaseMajorVersion(): Int = 3 + + override fun getDatabaseMinorVersion(): Int = 51 + + override fun getDriverName(): String = "Selekt JDBC Driver" + + override fun getDriverVersion(): String = "4.3" + + override fun getDriverMajorVersion(): Int = 4 + + override fun getDriverMinorVersion(): Int = 3 + + override fun getJDBCMajorVersion(): Int = 4 + + override fun getJDBCMinorVersion(): Int = 3 + + override fun getSQLKeywords(): String = "ABORT,AUTOINCREMENT,CONFLICT,FAIL,GLOB,IGNORE,ISNULL,NOTNULL,OFFSET," + + "PRAGMA,RAISE,REPLACE,VACUUM" + + override fun getSearchStringEscape(): String = "\\" + + override fun getMaxCatalogNameLength(): Int = 0 + + override fun getNumericFunctions(): String = "ABS,HEX,LENGTH,LOWER,LTRIM,MAX,MIN,NULLIF,QUOTE,RANDOM,REPLACE,ROUND," + + "RTRIM,SUBSTR,TRIM,TYPEOF,UPPER" + + override fun getStringFunctions(): String = "LENGTH,LOWER,LTRIM,REPLACE,RTRIM,SUBSTR,TRIM,UPPER" + + override fun getSystemFunctions(): String = "COALESCE,IFNULL,LAST_INSERT_ROWID,NULLIF,SQLITE_VERSION" + + override fun getTimeDateFunctions(): String = "DATE,DATETIME,JULIANDAY,STRFTIME,TIME" + + override fun getIdentifierQuoteString(): String = "\"" + + override fun getSQLStateType(): Int = DatabaseMetaData.sqlStateSQL99 + + override fun getExtraNameCharacters(): String = "" + + override fun isCatalogAtStart(): Boolean = false + + override fun getCatalogSeparator(): String = "." + + override fun getCatalogTerm(): String = "" + + override fun getSchemaTerm(): String = "" + + override fun getProcedureTerm(): String = "table" + + override fun getMaxBinaryLiteralLength(): Int = 0 + + override fun getMaxCharLiteralLength(): Int = 0 + + override fun getMaxColumnNameLength(): Int = 0 + + override fun getMaxColumnsInGroupBy(): Int = 0 + + override fun getMaxColumnsInIndex(): Int = 0 + + override fun getMaxColumnsInOrderBy(): Int = 0 + + override fun getMaxColumnsInSelect(): Int = 0 + + override fun getMaxColumnsInTable(): Int = 0 + + override fun getMaxConnections(): Int = 0 + + override fun getMaxCursorNameLength(): Int = 0 + + override fun getMaxIndexLength(): Int = 0 + + override fun getMaxProcedureNameLength(): Int = 0 + + override fun getMaxRowSize(): Int = 0 + + override fun getMaxSchemaNameLength(): Int = 0 + + override fun getMaxStatementLength(): Int = 0 + + override fun getMaxStatements(): Int = 0 + + override fun getMaxTableNameLength(): Int = 0 + + override fun getMaxTablesInSelect(): Int = 0 + + override fun getMaxUserNameLength(): Int = 0 + + override fun supportsAlterTableWithAddColumn(): Boolean = true + + override fun supportsAlterTableWithDropColumn(): Boolean = true + + override fun supportsANSI92EntryLevelSQL(): Boolean = true + + override fun supportsANSI92FullSQL(): Boolean = false + + override fun supportsANSI92IntermediateSQL(): Boolean = false + + override fun supportsBatchUpdates(): Boolean = true + + override fun supportsCatalogsInDataManipulation(): Boolean = false + + override fun supportsCatalogsInIndexDefinitions(): Boolean = false + + override fun supportsCatalogsInPrivilegeDefinitions(): Boolean = false + + override fun supportsCatalogsInProcedureCalls(): Boolean = false + + override fun supportsCatalogsInTableDefinitions(): Boolean = false + + override fun supportsColumnAliasing(): Boolean = false + + override fun supportsConvert(): Boolean = true + + override fun supportsConvert(fromType: Int, toType: Int): Boolean = true + + override fun supportsCorrelatedSubqueries(): Boolean = true + + @Suppress("FunctionMaxLength") + override fun supportsDataDefinitionAndDataManipulationTransactions(): Boolean = true + + override fun supportsDataManipulationTransactionsOnly(): Boolean = false + override fun supportsDifferentTableCorrelationNames(): Boolean = true + + override fun supportsExpressionsInOrderBy(): Boolean = true + + override fun supportsExtendedSQLGrammar(): Boolean = false + + override fun supportsFullOuterJoins(): Boolean = true + + override fun supportsGroupBy(): Boolean = true + + override fun supportsGroupByBeyondSelect(): Boolean = true + + override fun supportsGroupByUnrelated(): Boolean = true + + override fun supportsIntegrityEnhancementFacility(): Boolean = false + + override fun supportsLikeEscapeClause(): Boolean = true + + override fun supportsLimitedOuterJoins(): Boolean = true + + override fun supportsMinimumSQLGrammar(): Boolean = true + + override fun supportsMixedCaseIdentifiers(): Boolean = false + + override fun supportsMixedCaseQuotedIdentifiers(): Boolean = true + + override fun supportsMultipleOpenResults(): Boolean = false + + override fun supportsMultipleResultSets(): Boolean = false + + override fun supportsMultipleTransactions(): Boolean = false + + override fun supportsNonNullableColumns(): Boolean = true + + override fun supportsOpenCursorsAcrossCommit(): Boolean = false + + override fun supportsOpenCursorsAcrossRollback(): Boolean = false + + override fun supportsOpenStatementsAcrossCommit(): Boolean = false + + override fun supportsOpenStatementsAcrossRollback(): Boolean = false + + override fun supportsOrderByUnrelated(): Boolean = true + + override fun supportsOuterJoins(): Boolean = true + + override fun supportsPositionedDelete(): Boolean = false + + override fun supportsPositionedUpdate(): Boolean = false + + override fun supportsSelectForUpdate(): Boolean = false + + override fun supportsStoredProcedures(): Boolean = false + + override fun supportsSubqueriesInComparisons(): Boolean = true + + override fun supportsSubqueriesInExists(): Boolean = true + + override fun supportsSubqueriesInIns(): Boolean = true + + override fun supportsSubqueriesInQuantifieds(): Boolean = true + + override fun supportsTableCorrelationNames(): Boolean = false + + override fun supportsTransactions(): Boolean = true + + override fun supportsUnion(): Boolean = true + + override fun supportsUnionAll(): Boolean = true + + override fun getDefaultTransactionIsolation(): Int = Connection.TRANSACTION_SERIALIZABLE + + override fun supportsTransactionIsolationLevel(level: Int): Boolean = Connection.TRANSACTION_SERIALIZABLE == level + + override fun supportsGetGeneratedKeys(): Boolean = false + + override fun supportsResultSetType(type: Int): Boolean = when (type) { + ResultSet.TYPE_FORWARD_ONLY -> true + else -> false + } + + override fun supportsResultSetConcurrency( + type: Int, + concurrency: Int + ): Boolean = ResultSet.CONCUR_READ_ONLY == concurrency + + override fun supportsResultSetHoldability(holdability: Int): Boolean = ResultSet.CLOSE_CURSORS_AT_COMMIT == holdability + + override fun storesLowerCaseIdentifiers(): Boolean = false + + override fun storesLowerCaseQuotedIdentifiers(): Boolean = false + + override fun storesMixedCaseIdentifiers(): Boolean = true + + override fun storesMixedCaseQuotedIdentifiers(): Boolean = true + + override fun storesUpperCaseIdentifiers(): Boolean = false + + override fun storesUpperCaseQuotedIdentifiers(): Boolean = false + + override fun nullsAreSortedAtEnd(): Boolean = false + + override fun nullsAreSortedAtStart(): Boolean = true + + override fun nullsAreSortedHigh(): Boolean = false + + override fun nullsAreSortedLow(): Boolean = false + + override fun allProceduresAreCallable(): Boolean = false + + override fun allTablesAreSelectable(): Boolean = true + + override fun getURL(): String = connectionURL.toString() + + override fun getUserName(): String = "" + + override fun isReadOnly(): Boolean = connection.isReadOnly + + override fun usesLocalFiles(): Boolean = true + + override fun usesLocalFilePerTable(): Boolean = true + + override fun doesMaxRowSizeIncludeBlobs(): Boolean = true + + override fun locatorsUpdateCopy(): Boolean = false + + private fun executeMetadataQuery(sql: String): ResultSet = JdbcResultSet( + database.query(sql, emptyArray()), + null + ) + + @Suppress("Detekt.StringLiteralDuplication") + private fun mapSQLiteTypeToJDBCType(sqliteType: String): Int = when (sqliteType.uppercase()) { + "INTEGER", "INT", "SMALLINT", "MEDIUMINT", "BIGINT" -> Types.INTEGER + "REAL", "DOUBLE", "DOUBLE PRECISION", "FLOAT" -> Types.REAL + "NUMERIC", "DECIMAL" -> Types.NUMERIC + "TEXT", "CLOB", "VARCHAR", "CHAR", "CHARACTER" -> Types.VARCHAR + "BLOB" -> Types.BLOB + "NULL" -> Types.NULL + else -> Types.VARCHAR + } + + @Suppress("Detekt.StringLiteralDuplication") + private fun mapSQLiteTypeToJDBCTypeName(sqliteType: String): String = when (sqliteType.uppercase()) { + "INTEGER", "INT", "SMALLINT", "MEDIUMINT", "BIGINT" -> "INTEGER" + "REAL", "DOUBLE", "DOUBLE PRECISION", "FLOAT" -> "REAL" + "NUMERIC", "DECIMAL" -> "NUMERIC" + "TEXT", "CLOB" -> "TEXT" + "VARCHAR", "CHAR", "CHARACTER" -> sqliteType.uppercase() + "BLOB" -> "BLOB" + "NULL" -> "NULL" + else -> sqliteType.uppercase() + } + + @Suppress("Detekt.MagicNumber") + private fun getColumnSizeForType(sqliteType: String): Int = when (sqliteType.uppercase()) { + "INTEGER", "INT" -> 10 + "SMALLINT" -> 5 + "MEDIUMINT" -> 7 + "BIGINT" -> 19 + "REAL", "FLOAT" -> 15 + "DOUBLE", "DOUBLE PRECISION" -> 15 + "NUMERIC", "DECIMAL" -> 10 + "TEXT", "CLOB", "VARCHAR" -> Int.MAX_VALUE + "CHAR", "CHARACTER" -> 1 + "BLOB" -> Int.MAX_VALUE + else -> 255 + } + + override fun getTables( + catalog: String?, + schemaPattern: String?, + tableNamePattern: String?, + types: Array? + ): ResultSet { + val namePattern = tableNamePattern?.replace("%", "*") ?: "*" + val whereClause = if (tableNamePattern != null) { + "AND name GLOB '$namePattern'" + } else { + "" + } + val sql = """ + SELECT + NULL as TABLE_CAT, + NULL as TABLE_SCHEM, + name as TABLE_NAME, + CASE + WHEN type = 'table' THEN 'TABLE' + WHEN type = 'view' THEN 'VIEW' + ELSE UPPER(type) + END as TABLE_TYPE, + '' as REMARKS, + NULL as TYPE_CAT, + NULL as TYPE_SCHEM, + NULL as TYPE_NAME, + NULL as SELF_REFERENCING_COL_NAME, + NULL as REF_GENERATION + FROM sqlite_master + WHERE type IN ('table', 'view') $whereClause + AND name NOT LIKE 'sqlite_%' + ORDER BY TABLE_TYPE, TABLE_NAME + """.trimIndent() + return executeMetadataQuery(sql) + } + + @Suppress("Detekt.CognitiveComplexMethod", "Detekt.LongMethod", "Detekt.NestedBlockDepth") + override fun getColumns( + catalog: String?, + schemaPattern: String?, + tableNamePattern: String?, + columnNamePattern: String? + ): ResultSet { + val tablesResult = getTables(catalog, schemaPattern, tableNamePattern, arrayOf("TABLE", "VIEW")) + val columnRows = mutableListOf() + tablesResult.use { tablesResult -> + while (tablesResult.next()) { + val tableName = tablesResult.getString("TABLE_NAME") + val pragmaSql = "PRAGMA table_info('$tableName')" + val pragmaResult = executeMetadataQuery(pragmaSql) + pragmaResult.use { pragmaResult -> + while (pragmaResult.next()) { + val columnName = pragmaResult.getString("name") + val dataType = pragmaResult.getString("type") + val notNull = pragmaResult.getInt("notnull") + val defaultValue = pragmaResult.getString("dflt_value") + val primaryKey = pragmaResult.getInt("pk") + if (columnNamePattern != null && !columnName.matches( + columnNamePattern.replace("%", ".*").toRegex() + ) + ) { + continue + } + val sqlType = mapSQLiteTypeToJDBCType(dataType) + val typeName = mapSQLiteTypeToJDBCTypeName(dataType) + val columnSize = getColumnSizeForType(dataType) + columnRows.add(""" + SELECT + NULL as TABLE_CAT, + NULL as TABLE_SCHEM, + '$tableName' as TABLE_NAME, + '$columnName' as COLUMN_NAME, + $sqlType as DATA_TYPE, + '$typeName' as TYPE_NAME, + $columnSize as COLUMN_SIZE, + NULL as BUFFER_LENGTH, + NULL as DECIMAL_DIGITS, + 10 as NUM_PREC_RADIX, + ${if (notNull == 1) { "0" } else { "1" } } as NULLABLE, + '' as REMARKS, + ${if (defaultValue != null) { "'$defaultValue'" } else { "NULL" } } as COLUMN_DEF, + NULL as SQL_DATA_TYPE, + NULL as SQL_DATETIME_SUB, + NULL as CHAR_OCTET_LENGTH, + ${pragmaResult.row + 1} as ORDINAL_POSITION, + '${if (notNull == 1) { "NO" } else { "YES" } }' as IS_NULLABLE, + NULL as SCOPE_CATALOG, + NULL as SCOPE_SCHEMA, + NULL as SCOPE_TABLE, + NULL as SOURCE_DATA_TYPE, + '${if (primaryKey > 0) { "YES" } else { "NO" } }' as IS_AUTOINCREMENT, + '${if (primaryKey > 0) { "YES" } else { "NO" } }' as IS_GENERATEDCOLUMN + """.trimIndent()) + } + } + } + } + val unionSql = if (columnRows.isNotEmpty()) { + columnRows.joinToString("\nUNION ALL\n") + "\nORDER BY TABLE_NAME, ORDINAL_POSITION" + } else { + """ + SELECT + NULL as TABLE_CAT, NULL as TABLE_SCHEM, NULL as TABLE_NAME, + NULL as COLUMN_NAME, NULL as DATA_TYPE, NULL as TYPE_NAME, + NULL as COLUMN_SIZE, NULL as BUFFER_LENGTH, NULL as DECIMAL_DIGITS, + NULL as NUM_PREC_RADIX, NULL as NULLABLE, NULL as REMARKS, + NULL as COLUMN_DEF, NULL as SQL_DATA_TYPE, NULL as SQL_DATETIME_SUB, + NULL as CHAR_OCTET_LENGTH, NULL as ORDINAL_POSITION, NULL as IS_NULLABLE, + NULL as SCOPE_CATALOG, NULL as SCOPE_SCHEMA, NULL as SCOPE_TABLE, + NULL as SOURCE_DATA_TYPE, NULL as IS_AUTOINCREMENT, NULL as IS_GENERATEDCOLUMN + WHERE 1 = 0 + """.trimIndent() + } + return executeMetadataQuery(unionSql) + } + + override fun getPrimaryKeys(catalog: String?, schema: String?, table: String): ResultSet { + val sql = "PRAGMA table_info('$table')" + val pragmaResult = executeMetadataQuery(sql) + val pkRows = mutableListOf() + pragmaResult.use { pragmaResult -> + while (pragmaResult.next()) { + val primaryKey = pragmaResult.getInt("pk") + if (primaryKey > 0) { + pkRows.add(""" + SELECT + NULL as TABLE_CAT, + NULL as TABLE_SCHEM, + '$table' as TABLE_NAME, + '${pragmaResult.getString("name")}' as COLUMN_NAME, + $primaryKey as KEY_SEQ, + 'PRIMARY' as PK_NAME + """.trimIndent()) + } + } + } + val unionSql = if (pkRows.isNotEmpty()) { + pkRows.joinToString("\nUNION ALL\n") + "\nORDER BY KEY_SEQ" + } else { + """ + SELECT + NULL as TABLE_CAT, NULL as TABLE_SCHEM, NULL as TABLE_NAME, + NULL as COLUMN_NAME, NULL as KEY_SEQ, NULL as PK_NAME + WHERE 1 = 0 + """.trimIndent() + } + return executeMetadataQuery(unionSql) + } + + override fun getIndexInfo( + catalog: String?, + schema: String?, + table: String, + unique: Boolean, + approximate: Boolean + ): ResultSet = executeMetadataQuery(""" + SELECT + NULL as TABLE_CAT, + NULL as TABLE_SCHEM, + '$table' as TABLE_NAME, + CASE WHEN "unique" = 1 THEN 0 ELSE 1 END as NON_UNIQUE, + NULL as INDEX_QUALIFIER, + name as INDEX_NAME, + 3 as TYPE, + 0 as ORDINAL_POSITION, + NULL as COLUMN_NAME, + NULL as ASC_OR_DESC, + 0 as CARDINALITY, + 0 as PAGES, + NULL as FILTER_CONDITION + FROM sqlite_master + WHERE type = 'index' AND tbl_name = '$table' + ${if (unique) { "AND \"unique\" = 1" } else { "" } } + """.trimIndent()) + + override fun getProcedures( + catalog: String?, + schemaPattern: String?, + procedureNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getProcedures not implemented") + + override fun getProcedureColumns( + catalog: String?, + schemaPattern: String?, + procedureNamePattern: String?, + columnNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getProcedureColumns not implemented") + + override fun getSchemas(): ResultSet = getSchemas(null, null) + override fun getSchemas( + catalog: String?, + schemaPattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getSchemas not implemented") + + override fun getCatalogs(): ResultSet = throw SQLFeatureNotSupportedException("getCatalogs not implemented") + + override fun getTableTypes(): ResultSet = executeMetadataQuery(""" + SELECT 'TABLE' as TABLE_TYPE + UNION ALL + SELECT 'VIEW' as TABLE_TYPE + ORDER BY TABLE_TYPE + """.trimIndent()) + + @Suppress("Detekt.LongMethod") + override fun getTypeInfo(): ResultSet = executeMetadataQuery(""" + SELECT + 'INTEGER' as TYPE_NAME, + ${Types.INTEGER} as DATA_TYPE, + 10 as PRECISION, + NULL as LITERAL_PREFIX, + NULL as LITERAL_SUFFIX, + NULL as CREATE_PARAMS, + 1 as NULLABLE, + 1 as CASE_SENSITIVE, + 2 as SEARCHABLE, + 0 as UNSIGNED_ATTRIBUTE, + 0 as FIXED_PREC_SCALE, + 0 as AUTO_INCREMENT, + 'INTEGER' as LOCAL_TYPE_NAME, + 0 as MINIMUM_SCALE, + 0 as MAXIMUM_SCALE, + NULL as SQL_DATA_TYPE, + NULL as SQL_DATETIME_SUB, + 10 as NUM_PREC_RADIX + UNION ALL + SELECT + 'REAL' as TYPE_NAME, + ${Types.REAL} as DATA_TYPE, + 15 as PRECISION, + NULL as LITERAL_PREFIX, + NULL as LITERAL_SUFFIX, + NULL as CREATE_PARAMS, + 1 as NULLABLE, + 0 as CASE_SENSITIVE, + 2 as SEARCHABLE, + 0 as UNSIGNED_ATTRIBUTE, + 0 as FIXED_PREC_SCALE, + 0 as AUTO_INCREMENT, + 'REAL' as LOCAL_TYPE_NAME, + 0 as MINIMUM_SCALE, + 15 as MAXIMUM_SCALE, + NULL as SQL_DATA_TYPE, + NULL as SQL_DATETIME_SUB, + 10 as NUM_PREC_RADIX + UNION ALL + SELECT + 'TEXT' as TYPE_NAME, + ${Types.VARCHAR} as DATA_TYPE, + ${Int.MAX_VALUE} as PRECISION, + '''' as LITERAL_PREFIX, + '''' as LITERAL_SUFFIX, + NULL as CREATE_PARAMS, + 1 as NULLABLE, + 1 as CASE_SENSITIVE, + 3 as SEARCHABLE, + 0 as UNSIGNED_ATTRIBUTE, + 0 as FIXED_PREC_SCALE, + 0 as AUTO_INCREMENT, + 'TEXT' as LOCAL_TYPE_NAME, + 0 as MINIMUM_SCALE, + 0 as MAXIMUM_SCALE, + NULL as SQL_DATA_TYPE, + NULL as SQL_DATETIME_SUB, + NULL as NUM_PREC_RADIX + UNION ALL + SELECT + 'BLOB' as TYPE_NAME, + ${Types.BLOB} as DATA_TYPE, + ${Int.MAX_VALUE} as PRECISION, + NULL as LITERAL_PREFIX, + NULL as LITERAL_SUFFIX, + NULL as CREATE_PARAMS, + 1 as NULLABLE, + 0 as CASE_SENSITIVE, + 0 as SEARCHABLE, + 0 as UNSIGNED_ATTRIBUTE, + 0 as FIXED_PREC_SCALE, + 0 as AUTO_INCREMENT, + 'BLOB' as LOCAL_TYPE_NAME, + 0 as MINIMUM_SCALE, + 0 as MAXIMUM_SCALE, + NULL as SQL_DATA_TYPE, + NULL as SQL_DATETIME_SUB, + NULL as NUM_PREC_RADIX + UNION ALL + SELECT + 'NUMERIC' as TYPE_NAME, + ${Types.NUMERIC} as DATA_TYPE, + 10 as PRECISION, + NULL as LITERAL_PREFIX, + NULL as LITERAL_SUFFIX, + 'precision,scale' as CREATE_PARAMS, + 1 as NULLABLE, + 0 as CASE_SENSITIVE, + 2 as SEARCHABLE, + 0 as UNSIGNED_ATTRIBUTE, + 0 as FIXED_PREC_SCALE, + 0 as AUTO_INCREMENT, + 'NUMERIC' as LOCAL_TYPE_NAME, + 0 as MINIMUM_SCALE, + 10 as MAXIMUM_SCALE, + NULL as SQL_DATA_TYPE, + NULL as SQL_DATETIME_SUB, + 10 as NUM_PREC_RADIX + ORDER BY DATA_TYPE + """.trimIndent()) + + override fun getUDTs( + catalog: String?, + schemaPattern: String?, + typeNamePattern: String?, + types: IntArray? + ): ResultSet = throw SQLFeatureNotSupportedException("getUDTs not implemented") + + override fun getSuperTypes( + catalog: String?, + schemaPattern: String?, + typeNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getSuperTypes not implemented") + + override fun getSuperTables( + catalog: String?, + schemaPattern: String?, + tableNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getSuperTables not implemented") + + override fun getAttributes( + catalog: String?, + schemaPattern: String?, + typeNamePattern: String?, + attributeNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getAttributes not implemented") + + override fun getFunctions( + catalog: String?, + schemaPattern: String?, + functionNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getFunctions not implemented") + + override fun getFunctionColumns( + catalog: String?, + schemaPattern: String?, + functionNamePattern: String?, + columnNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getFunctionColumns not implemented") + + override fun getPseudoColumns( + catalog: String?, + schemaPattern: String?, + tableNamePattern: String?, + columnNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("getPseudoColumns not implemented") + + override fun generatedKeyAlwaysReturned(): Boolean = false + + override fun dataDefinitionCausesTransactionCommit(): Boolean = false + + override fun dataDefinitionIgnoredInTransactions(): Boolean = false + + override fun deletesAreDetected(type: Int): Boolean = false + + override fun insertsAreDetected(type: Int): Boolean = false + + override fun updatesAreDetected(type: Int): Boolean = false + + override fun othersDeletesAreVisible(type: Int): Boolean = false + + override fun othersInsertsAreVisible(type: Int): Boolean = false + + override fun othersUpdatesAreVisible(type: Int): Boolean = false + + override fun ownDeletesAreVisible(type: Int): Boolean = false + + override fun ownInsertsAreVisible(type: Int): Boolean = false + + override fun ownUpdatesAreVisible(type: Int): Boolean = false + + override fun supportsNamedParameters(): Boolean = false + + override fun supportsStatementPooling(): Boolean = false + + override fun supportsSavepoints(): Boolean = true + + override fun supportsStoredFunctionsUsingCallSyntax(): Boolean = false + + override fun supportsCoreSQLGrammar(): Boolean = true + + override fun supportsSchemasInDataManipulation(): Boolean = false + + override fun supportsSchemasInProcedureCalls(): Boolean = false + + override fun supportsSchemasInTableDefinitions(): Boolean = false + + override fun supportsSchemasInIndexDefinitions(): Boolean = false + + override fun supportsSchemasInPrivilegeDefinitions(): Boolean = false + + override fun nullPlusNonNullIsNull(): Boolean = true + + override fun autoCommitFailureClosesAllResultSets(): Boolean = false + + override fun getRowIdLifetime(): RowIdLifetime = RowIdLifetime.ROWID_UNSUPPORTED + + override fun getResultSetHoldability(): Int = ResultSet.HOLD_CURSORS_OVER_COMMIT + + override fun getColumnPrivileges( + catalog: String?, + schema: String?, + table: String?, + columnNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("SQLite does not support column privileges") + + override fun getTablePrivileges( + catalog: String?, + schemaPattern: String?, + tableNamePattern: String? + ): ResultSet = throw SQLFeatureNotSupportedException("SQLite does not support table privileges") + + override fun getBestRowIdentifier( + catalog: String?, + schema: String?, + table: String?, + scope: Int, + nullable: Boolean + ): ResultSet = throw SQLFeatureNotSupportedException("SQLite does not support best row identifier metadata") + + override fun getVersionColumns( + catalog: String?, + schema: String?, + table: String? + ): ResultSet = throw SQLFeatureNotSupportedException("SQLite does not support version columns") + + override fun getImportedKeys( + catalog: String?, + schema: String?, + table: String? + ): ResultSet = throw SQLFeatureNotSupportedException("Use PRAGMA foreign_key_list to get foreign key information") + + override fun getExportedKeys( + catalog: String?, + schema: String?, + table: String? + ): ResultSet { + throw SQLFeatureNotSupportedException("SQLite does not support exported key metadata") + } + + override fun getCrossReference( + parentCatalog: String?, + parentSchema: String?, + parentTable: String?, + foreignCatalog: String?, + foreignSchema: String?, + foreignTable: String? + ): ResultSet = throw SQLFeatureNotSupportedException("SQLite does not support cross reference metadata") + + override fun getClientInfoProperties(): ResultSet = throw SQLFeatureNotSupportedException( + "SQLite does not support client info properties" + ) + + override fun unwrap(iface: Class): T = if (iface.isAssignableFrom(this::class.java)) { + @Suppress("UNCHECKED_CAST") + this as T + } else if (iface.isAssignableFrom(SQLDatabase::class.java)) { + @Suppress("UNCHECKED_CAST") + database as T + } else { + throw SQLException("Cannot unwrap to ${iface.name}") + } + + override fun isWrapperFor( + iface: Class<*> + ): Boolean = iface.isAssignableFrom(this::class.java) || iface.isAssignableFrom(SQLDatabase::class.java) +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSet.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSet.kt new file mode 100644 index 0000000000..3dbf53153b --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSet.kt @@ -0,0 +1,1027 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.result + +import com.bloomberg.selekt.ColumnType +import com.bloomberg.selekt.ICursor +import com.bloomberg.selekt.jdbc.exception.SQLExceptionMapper +import com.bloomberg.selekt.jdbc.lob.JdbcClob +import com.bloomberg.selekt.jdbc.util.TypeMapping +import java.io.InputStream +import java.io.Reader +import java.math.BigDecimal +import java.net.URL +import java.sql.Blob +import java.sql.Clob +import java.sql.Date +import java.sql.NClob +import java.sql.Ref +import java.sql.ResultSet +import java.sql.ResultSetMetaData +import java.sql.RowId +import java.sql.SQLException +import java.sql.SQLFeatureNotSupportedException +import java.sql.SQLXML +import java.sql.SQLWarning +import java.sql.Statement +import java.sql.Time +import java.sql.Timestamp +import java.sql.Types +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import java.util.Calendar +import javax.annotation.concurrent.NotThreadSafe + +private const val READ_ONLY_ERROR = "ResultSet is read-only" + +@NotThreadSafe +@Suppress("MethodOverloading", "LargeClass", "TooGenericExceptionCaught") +internal class JdbcResultSet( + private val cursor: ICursor, + private val statement: Statement?, + private val resultSetType: Int = ResultSet.TYPE_FORWARD_ONLY, + private val resultSetConcurrency: Int = ResultSet.CONCUR_READ_ONLY, + private val resultSetHoldability: Int = ResultSet.CLOSE_CURSORS_AT_COMMIT +) : ResultSet { + private var wasNull = false + private val metadata by lazy { JdbcResultSetMetaData(cursor) } + private var fetchSize = 0 + + override fun next(): Boolean { + checkClosed() + val result = cursor.moveToNext() + return result + } + + override fun close() { + if (!cursor.isClosed()) { + cursor.close() + } + } + + override fun wasNull(): Boolean = wasNull + + override fun getString(columnIndex: Int): String? { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + null + } else { + wasNull = false + cursor.getString(columnIndex - 1) + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting string from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getString(columnLabel: String): String? = getString(findColumn(columnLabel)) + + override fun getBoolean(columnIndex: Int): Boolean { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + false + } else { + wasNull = false + runCatching { + cursor.getLong(columnIndex - 1) != 0L + }.getOrElse { _ -> + cursor.getString(columnIndex - 1)?.toBoolean() ?: false + } + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting boolean from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getBoolean(columnLabel: String): Boolean = getBoolean(findColumn(columnLabel)) + + override fun getByte(columnIndex: Int): Byte { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + 0 + } else { + wasNull = false + cursor.getInt(columnIndex - 1).toByte() + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting byte from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getByte(columnLabel: String): Byte = getByte(findColumn(columnLabel)) + + override fun getShort(columnIndex: Int): Short { + checkClosed() + validateColumnIndex(columnIndex) + return try { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + 0 + } else { + wasNull = false + cursor.getShort(columnIndex - 1) + } + } catch (e: SQLException) { + throw SQLExceptionMapper.mapException(e) + } catch (e: RuntimeException) { + throw SQLExceptionMapper.mapException( + "Error getting short from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getShort(columnLabel: String): Short = getShort(findColumn(columnLabel)) + + override fun getInt(columnIndex: Int): Int { + checkClosed() + validateColumnIndex(columnIndex) + return try { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + 0 + } else { + wasNull = false + when (cursor.type(columnIndex - 1)) { + ColumnType.INTEGER -> cursor.getInt(columnIndex - 1) + ColumnType.STRING -> { + val stringValue = cursor.getString(columnIndex - 1) + TypeMapping.convertFromSQLite(stringValue, Types.INTEGER) as Int + } + ColumnType.FLOAT -> cursor.getDouble(columnIndex - 1).toInt() + ColumnType.NULL -> 0 + ColumnType.BLOB -> 0 // Cannot convert blob to int + } + } + } catch (e: SQLException) { + throw SQLExceptionMapper.mapException(e) + } catch (e: RuntimeException) { + throw SQLExceptionMapper.mapException( + "Error getting int from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getInt(columnLabel: String): Int = getInt(findColumn(columnLabel)) + + override fun getLong(columnIndex: Int): Long { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + 0L + } else { + wasNull = false + cursor.getLong(columnIndex - 1) + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting long from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getLong(columnLabel: String): Long = getLong(findColumn(columnLabel)) + + override fun getFloat(columnIndex: Int): Float { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + 0f + } else { + wasNull = false + cursor.getFloat(columnIndex - 1) + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting float from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getFloat(columnLabel: String): Float = getFloat(findColumn(columnLabel)) + + override fun getDouble(columnIndex: Int): Double { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + 0.0 + } else { + wasNull = false + val columnType = cursor.type(columnIndex - 1) + + when (columnType) { + ColumnType.FLOAT -> cursor.getDouble(columnIndex - 1) + ColumnType.INTEGER -> cursor.getInt(columnIndex - 1).toDouble() + ColumnType.STRING -> { + val stringValue = cursor.getString(columnIndex - 1) + TypeMapping.convertFromSQLite(stringValue, Types.DOUBLE) as Double + } + ColumnType.NULL -> 0.0 + ColumnType.BLOB -> 0.0 // Cannot convert blob to double + } + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting double from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getDouble(columnLabel: String): Double = getDouble(findColumn(columnLabel)) + + override fun getBigDecimal(columnIndex: Int): BigDecimal? { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + null + } else { + wasNull = false + val value = cursor.getDouble(columnIndex - 1) + BigDecimal.valueOf(value) + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting BigDecimal from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + @Deprecated("Deprecated in Java", ReplaceWith("getBigDecimal(columnIndex)")) + @Suppress("DEPRECATION") + override fun getBigDecimal(columnIndex: Int, scale: Int): BigDecimal? = getBigDecimal(columnIndex)?.setScale( + scale, BigDecimal.ROUND_HALF_UP) + + override fun getBigDecimal(columnLabel: String): BigDecimal? = getBigDecimal(findColumn(columnLabel)) + + @Deprecated("Deprecated in Java", ReplaceWith("getBigDecimal(columnLabel)")) + @Suppress("DEPRECATION") + override fun getBigDecimal( + columnLabel: String, + scale: Int + ): BigDecimal? = getBigDecimal(columnLabel)?.setScale(scale, BigDecimal.ROUND_HALF_UP) + + override fun getBytes(columnIndex: Int): ByteArray? { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + null + } else { + wasNull = false + cursor.getBlob(columnIndex - 1) + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting bytes from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getBytes(columnLabel: String): ByteArray? = getBytes(findColumn(columnLabel)) + + override fun getDate(columnIndex: Int): Date? { + val dateString = getString(columnIndex) + return if (wasNull) { + null + } else { + runCatching { + dateString?.let { + TypeMapping.convertFromSQLite(it, Types.DATE) as? Date + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error parsing date from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + } + + override fun getDate(columnLabel: String): Date? = getDate(findColumn(columnLabel)) + + override fun getDate(columnIndex: Int, cal: Calendar?): Date? = getDate(columnIndex) + + override fun getDate(columnLabel: String, cal: Calendar?): Date? = getDate(columnLabel) + + override fun getTime(columnIndex: Int): Time? { + val timeString = getString(columnIndex) + return if (wasNull) { + null + } else { + runCatching { + timeString?.let { + TypeMapping.convertFromSQLite(it, Types.TIME) as? Time + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error parsing time from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + } + + override fun getTime(columnLabel: String): Time? = getTime(findColumn(columnLabel)) + + override fun getTime(columnIndex: Int, cal: Calendar?): Time? = getTime(columnIndex) + + override fun getTime(columnLabel: String, cal: Calendar?): Time? = getTime(columnLabel) + + override fun getTimestamp(columnIndex: Int): Timestamp? { + val timestampString = getString(columnIndex) + return if (wasNull) { + null + } else { + runCatching { + timestampString?.let { + TypeMapping.convertFromSQLite(it, Types.TIMESTAMP) as? Timestamp + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error parsing timestamp from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + } + + override fun getTimestamp(columnLabel: String): Timestamp? = getTimestamp(findColumn(columnLabel)) + + override fun getTimestamp(columnIndex: Int, cal: Calendar?): Timestamp? = getTimestamp(columnIndex) + + override fun getTimestamp(columnLabel: String, cal: Calendar?): Timestamp? = getTimestamp(columnLabel) + + override fun findColumn(columnLabel: String): Int { + checkClosed() + when (val index = cursor.columnIndex(columnLabel)) { + -1 -> throw SQLException("Column '$columnLabel' not found") + else -> return index + 1 + } + } + + override fun isBeforeFirst(): Boolean { + checkClosed() + return cursor.isBeforeFirst() + } + + override fun isAfterLast(): Boolean { + checkClosed() + return cursor.isAfterLast() + } + + override fun isFirst(): Boolean { + checkClosed() + return cursor.isFirst() + } + + override fun isLast(): Boolean { + checkClosed() + return cursor.isLast() + } + + override fun beforeFirst() { + checkClosed() + checkScrollable() + cursor.moveToPosition(-1) + } + + override fun afterLast() { + checkClosed() + checkScrollable() + cursor.moveToPosition(cursor.count) + } + + override fun first(): Boolean { + checkClosed() + checkScrollable() + return cursor.moveToFirst() + } + + override fun last(): Boolean { + checkClosed() + checkScrollable() + return cursor.moveToLast() + } + + override fun getRow(): Int { + checkClosed() + return if (cursor.isBeforeFirst() || cursor.isAfterLast()) { 0 } else { cursor.position() + 1 } + } + + override fun absolute(row: Int): Boolean { + checkClosed() + checkScrollable() + return cursor.moveToPosition(row - 1) + } + + override fun relative(rows: Int): Boolean { + checkClosed() + checkScrollable() + return cursor.move(rows) + } + + override fun previous(): Boolean { + checkClosed() + checkScrollable() + return cursor.moveToPrevious() + } + + override fun getMetaData(): ResultSetMetaData = metadata + + override fun getObject(columnIndex: Int): Any? { + checkClosed() + validateColumnIndex(columnIndex) + return runCatching { + if (cursor.isNull(columnIndex - 1)) { + wasNull = true + null + } else { + wasNull = false + val columnType = cursor.type(columnIndex - 1) + when (columnType) { + ColumnType.INTEGER -> cursor.getLong(columnIndex - 1) + ColumnType.FLOAT -> cursor.getDouble(columnIndex - 1) + ColumnType.STRING -> cursor.getString(columnIndex - 1) + ColumnType.BLOB -> cursor.getBlob(columnIndex - 1) + ColumnType.NULL -> null + } + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException( + "Error getting object from column $columnIndex: ${e.message}", + -1, + -1, + e + ) + } + } + + override fun getObject(columnLabel: String): Any? = getObject(findColumn(columnLabel)) + + override fun getObject(columnIndex: Int, type: Class): T? { + val value = getObject(columnIndex) + @Suppress("UNCHECKED_CAST") + return when { + value == null || wasNull -> null as T? + type.isInstance(value) -> value as T + type == String::class.java -> value.toString() as T + type == Int::class.java && value is Number -> value.toInt() as T + type == Long::class.java && value is Number -> value.toLong() as T + type == Double::class.java && value is Number -> value.toDouble() as T + type == Float::class.java && value is Number -> value.toFloat() as T + type == Boolean::class.java -> when (value) { + is Boolean -> value + is Number -> value.toLong() != 0L + is String -> value.equals("true", ignoreCase = true) || value == "1" + else -> false + } as T + type == LocalDate::class.java && value is String -> LocalDate.parse(value) as T + type == LocalTime::class.java && value is String -> LocalTime.parse(value) as T + type == LocalDateTime::class.java && value is String -> LocalDateTime.parse(value) as T + else -> throw SQLException("Cannot convert ${value.javaClass.name} to ${type.name}") + } + } + + override fun getObject(columnLabel: String, type: Class): T? = getObject(findColumn(columnLabel), type) + + override fun getObject(columnIndex: Int, map: MutableMap>?): Any? = getObject(columnIndex) + + override fun getObject(columnLabel: String, map: MutableMap>?): Any? = getObject(columnLabel) + + override fun getAsciiStream( + columnIndex: Int + ): InputStream? = getString(columnIndex)?.byteInputStream(Charsets.US_ASCII) + + override fun getAsciiStream(columnLabel: String): InputStream? = getAsciiStream(findColumn(columnLabel)) + + @Deprecated("Deprecated in Java", ReplaceWith("getCharacterStream(columnIndex)")) + override fun getUnicodeStream( + columnIndex: Int + ): InputStream? = getString(columnIndex)?.byteInputStream(Charsets.UTF_16) + + @Deprecated("Deprecated in Java", ReplaceWith("getCharacterStream(columnLabel)")) + @Suppress("DEPRECATION") + override fun getUnicodeStream(columnLabel: String): InputStream? = getUnicodeStream(findColumn(columnLabel)) + + override fun getBinaryStream(columnIndex: Int): InputStream? = getBytes(columnIndex)?.inputStream() + + override fun getBinaryStream(columnLabel: String): InputStream? = getBinaryStream(findColumn(columnLabel)) + + override fun getCharacterStream(columnIndex: Int): Reader? = getString(columnIndex)?.reader() + + override fun getCharacterStream(columnLabel: String): Reader? = getCharacterStream(findColumn(columnLabel)) + + override fun getWarnings(): SQLWarning? = null + + override fun clearWarnings() {} + + override fun getCursorName(): String = throw SQLFeatureNotSupportedException("Named cursors not supported") + + override fun getStatement(): Statement? = statement + + override fun getType(): Int = resultSetType + + override fun getConcurrency(): Int = resultSetConcurrency + + override fun getHoldability(): Int = resultSetHoldability + + override fun isClosed(): Boolean = cursor.isClosed() + + override fun setFetchDirection(direction: Int) { + if (direction != ResultSet.FETCH_FORWARD) { + throw SQLFeatureNotSupportedException("Only FETCH_FORWARD is supported") + } + } + + override fun getFetchDirection(): Int = ResultSet.FETCH_FORWARD + + override fun setFetchSize(rows: Int) { + checkClosed() + if (rows < 0) { + throw SQLException("Fetch size must be >= 0, got: $rows") + } + fetchSize = rows + } + + override fun getFetchSize(): Int = fetchSize + + override fun getNClob(columnIndex: Int): NClob = throw SQLFeatureNotSupportedException() + + override fun getNClob(columnLabel: String): NClob = throw SQLFeatureNotSupportedException() + + override fun getSQLXML(columnIndex: Int): SQLXML = throw SQLFeatureNotSupportedException() + + override fun getSQLXML(columnLabel: String): SQLXML = throw SQLFeatureNotSupportedException() + + override fun getURL(columnIndex: Int): URL = throw SQLFeatureNotSupportedException() + + override fun getURL(columnLabel: String): URL = throw SQLFeatureNotSupportedException() + + override fun getArray(columnIndex: Int): java.sql.Array = throw SQLFeatureNotSupportedException() + + override fun getArray(columnLabel: String): java.sql.Array = throw SQLFeatureNotSupportedException() + + override fun getBlob(columnIndex: Int): Blob = throw SQLFeatureNotSupportedException() + + override fun getBlob(columnLabel: String): Blob = throw SQLFeatureNotSupportedException() + + override fun getClob(columnIndex: Int): Clob? { + checkClosed() + val text = getString(columnIndex) + wasNull = text == null + return if (text != null) JdbcClob(text) else null + } + + override fun getClob(columnLabel: String): Clob? { + checkClosed() + val text = getString(columnLabel) + wasNull = text == null + return if (text != null) { + JdbcClob(text) + } else { + null + } + } + + override fun getRef(columnIndex: Int): Ref = throw SQLFeatureNotSupportedException() + + override fun getRef(columnLabel: String): Ref = throw SQLFeatureNotSupportedException() + + override fun getNString(columnIndex: Int): String? = getString(columnIndex) + + override fun getNString(columnLabel: String): String? = getString(columnLabel) + + override fun getNCharacterStream(columnIndex: Int): Reader? = getCharacterStream(columnIndex) + + override fun getNCharacterStream(columnLabel: String): Reader? = getCharacterStream(columnLabel) + + override fun getRowId(columnIndex: Int): RowId = throw SQLFeatureNotSupportedException() + + override fun getRowId(columnLabel: String): RowId = throw SQLFeatureNotSupportedException() + + override fun rowUpdated(): Boolean = false + + override fun rowInserted(): Boolean = false + + override fun rowDeleted(): Boolean = false + + override fun updateNull(columnIndex: Int) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNull(columnLabel: String) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBoolean(columnIndex: Int, x: Boolean) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBoolean(columnLabel: String, x: Boolean) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateByte(columnIndex: Int, x: Byte) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateByte(columnLabel: String, x: Byte) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateShort(columnIndex: Int, x: Short) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateShort(columnLabel: String, x: Short) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateInt(columnIndex: Int, x: Int) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateInt(columnLabel: String, x: Int) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateLong(columnIndex: Int, x: Long) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateLong(columnLabel: String, x: Long) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateFloat(columnIndex: Int, x: Float) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateFloat(columnLabel: String, x: Float) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateDouble(columnIndex: Int, x: Double) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateDouble(columnLabel: String, x: Double) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBigDecimal(columnIndex: Int, x: BigDecimal?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBigDecimal( + columnLabel: String, + x: BigDecimal? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateString(columnIndex: Int, x: String?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateString(columnLabel: String, x: String?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBytes(columnIndex: Int, x: ByteArray?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBytes(columnLabel: String, x: ByteArray?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateDate(columnIndex: Int, x: Date?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateDate(columnLabel: String, x: Date?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateTime(columnIndex: Int, x: Time?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateTime(columnLabel: String, x: Time?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateTimestamp(columnIndex: Int, x: Timestamp?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateTimestamp(columnLabel: String, x: Timestamp?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateAsciiStream( + columnIndex: Int, + x: InputStream?, + length: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateAsciiStream( + columnLabel: String, + x: InputStream?, + length: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateAsciiStream( + columnIndex: Int, + x: InputStream?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateAsciiStream( + columnLabel: String, + x: InputStream?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateAsciiStream( + columnIndex: Int, + x: InputStream? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateAsciiStream( + columnLabel: String, + x: InputStream? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBinaryStream( + columnIndex: Int, + x: InputStream?, + length: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBinaryStream( + columnLabel: String, + x: InputStream?, + length: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBinaryStream( + columnIndex: Int, + x: InputStream?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBinaryStream( + columnLabel: String, + x: InputStream?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBinaryStream( + columnIndex: Int, + x: InputStream? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBinaryStream( + columnLabel: String, + x: InputStream? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateCharacterStream( + columnIndex: Int, + x: Reader?, + length: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateCharacterStream( + columnLabel: String, + reader: Reader?, + length: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateCharacterStream( + columnIndex: Int, + x: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateCharacterStream( + columnLabel: String, + reader: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateCharacterStream( + columnIndex: Int, + reader: Reader? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateCharacterStream( + columnLabel: String, + reader: Reader? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateObject( + columnIndex: Int, + x: Any?, + scaleOrLength: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateObject( + columnIndex: Int, + x: Any? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateObject( + columnLabel: String, + x: Any?, + scaleOrLength: Int + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateObject( + columnLabel: String, + x: Any? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun insertRow() = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateRow() = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun deleteRow() = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun refreshRow() = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun cancelRowUpdates() = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun moveToInsertRow() = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun moveToCurrentRow() = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateRef(columnIndex: Int, x: Ref?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateRef(columnLabel: String, x: Ref?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBlob(columnIndex: Int, x: Blob?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBlob(columnLabel: String, x: Blob?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBlob( + columnIndex: Int, + inputStream: InputStream?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBlob( + columnLabel: String, + inputStream: InputStream?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBlob( + columnIndex: Int, + inputStream: InputStream? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateBlob( + columnLabel: String, + inputStream: InputStream? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateClob(columnIndex: Int, x: Clob?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateClob(columnLabel: String, x: Clob?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateClob( + columnIndex: Int, + reader: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateClob( + columnLabel: String, + reader: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateClob(columnIndex: Int, reader: Reader?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateClob(columnLabel: String, reader: Reader?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateArray(columnIndex: Int, x: java.sql.Array?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateArray( + columnLabel: String, + x: java.sql.Array? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateRowId(columnIndex: Int, x: RowId?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateRowId(columnLabel: String, x: RowId?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNString(columnIndex: Int, nString: String?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNString( + columnLabel: String, + nString: String? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNClob(columnIndex: Int, nClob: NClob?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNClob(columnLabel: String, nClob: NClob?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNClob( + columnIndex: Int, + reader: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNClob( + columnLabel: String, + reader: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNClob(columnIndex: Int, reader: Reader?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNClob(columnLabel: String, reader: Reader?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateSQLXML(columnIndex: Int, xmlObject: SQLXML?) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateSQLXML( + columnLabel: String, + xmlObject: SQLXML? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNCharacterStream( + columnIndex: Int, + x: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNCharacterStream( + columnLabel: String, + reader: Reader?, + length: Long + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNCharacterStream( + columnIndex: Int, + x: Reader? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun updateNCharacterStream( + columnLabel: String, + reader: Reader? + ) = throw SQLFeatureNotSupportedException(READ_ONLY_ERROR) + + override fun unwrap(iface: Class): T = if (iface.isAssignableFrom(this::class.java)) { + @Suppress("UNCHECKED_CAST") + this as T + } else if (iface.isAssignableFrom(ICursor::class.java)) { + @Suppress("UNCHECKED_CAST") + cursor as T + } else { + throw SQLException("Cannot unwrap to ${iface.name}") + } + + override fun isWrapperFor( + iface: Class<*> + ): Boolean = iface.isAssignableFrom(this::class.java) || iface.isAssignableFrom(ICursor::class.java) + + private fun checkClosed() { + if (cursor.isClosed()) { + throw SQLException("ResultSet is closed") + } + } + + private fun checkScrollable() { + if (resultSetType == ResultSet.TYPE_FORWARD_ONLY) { + throw SQLException("ResultSet is TYPE_FORWARD_ONLY and does not support this operation") + } + } + + private fun validateColumnIndex(columnIndex: Int) { + if (columnIndex < 1 || columnIndex > cursor.columnCount) { + throw SQLException("Column index $columnIndex is out of range (1, ${cursor.columnCount})") + } + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetMetaData.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetMetaData.kt new file mode 100644 index 0000000000..0354edf335 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetMetaData.kt @@ -0,0 +1,195 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.result + +import com.bloomberg.selekt.ICursor +import com.bloomberg.selekt.jdbc.util.TypeMapping +import java.sql.ResultSetMetaData +import java.sql.SQLException +import java.sql.Types +import javax.annotation.concurrent.NotThreadSafe +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +@NotThreadSafe +@Suppress("TooGenericExceptionCaught") +internal class JdbcResultSetMetaData( + private val cursor: ICursor +) : ResultSetMetaData { + companion object { + private val logger: Logger = LoggerFactory.getLogger(JdbcResultSetMetaData::class.java) + + private const val BOOLEAN_DISPLAY_SIZE = 5 + private const val TINYINT_DISPLAY_SIZE = 4 + private const val SMALLINT_DISPLAY_SIZE = 6 + private const val INTEGER_DISPLAY_SIZE = 11 + private const val BIGINT_DISPLAY_SIZE = 20 + private const val FLOAT_DISPLAY_SIZE = 15 + private const val DOUBLE_DISPLAY_SIZE = 24 + private const val DATE_DISPLAY_SIZE = 10 + private const val TIME_DISPLAY_SIZE = 8 + private const val TIMESTAMP_DISPLAY_SIZE = 23 + } + + override fun getColumnCount(): Int = cursor.columnCount + + override fun isAutoIncrement(column: Int): Boolean { + validateColumnIndex(column) + return false + } + + override fun isCaseSensitive(column: Int): Boolean { + validateColumnIndex(column) + return getColumnType(column) == Types.VARCHAR + } + + override fun isSearchable(column: Int): Boolean { + validateColumnIndex(column) + return true + } + + override fun isCurrency(column: Int): Boolean { + validateColumnIndex(column) + return false + } + + override fun isNullable(column: Int): Int { + validateColumnIndex(column) + return ResultSetMetaData.columnNullableUnknown + } + + override fun isSigned(column: Int): Boolean { + validateColumnIndex(column) + return when (getColumnType(column)) { + Types.TINYINT, Types.SMALLINT, Types.INTEGER, Types.BIGINT, + Types.REAL, Types.FLOAT, Types.DOUBLE, Types.NUMERIC, Types.DECIMAL -> true + else -> false + } + } + + override fun getColumnDisplaySize(column: Int): Int { + validateColumnIndex(column) + return when (getColumnType(column)) { + Types.BOOLEAN -> BOOLEAN_DISPLAY_SIZE + Types.TINYINT -> TINYINT_DISPLAY_SIZE + Types.SMALLINT -> SMALLINT_DISPLAY_SIZE + Types.INTEGER -> INTEGER_DISPLAY_SIZE + Types.BIGINT -> BIGINT_DISPLAY_SIZE + Types.REAL, Types.FLOAT -> FLOAT_DISPLAY_SIZE + Types.DOUBLE -> DOUBLE_DISPLAY_SIZE + Types.DATE -> DATE_DISPLAY_SIZE + Types.TIME -> TIME_DISPLAY_SIZE + Types.TIMESTAMP -> TIMESTAMP_DISPLAY_SIZE + else -> Integer.MAX_VALUE + } + } + + override fun getColumnLabel(column: Int): String { + validateColumnIndex(column) + return cursor.columnName(column - 1) + } + + override fun getColumnName(column: Int): String { + validateColumnIndex(column) + return cursor.columnName(column - 1) + } + + override fun getSchemaName(column: Int): String { + validateColumnIndex(column) + return "" + } + + override fun getPrecision(column: Int): Int { + validateColumnIndex(column) + return TypeMapping.getPrecision(getColumnType(column)) + } + + override fun getScale(column: Int): Int { + validateColumnIndex(column) + return TypeMapping.getScale(getColumnType(column)) + } + + override fun getTableName(column: Int): String { + validateColumnIndex(column) + return "" + } + + override fun getCatalogName(column: Int): String { + validateColumnIndex(column) + return "" + } + + override fun getColumnType(column: Int): Int { + validateColumnIndex(column) + return try { + if (cursor.position() >= 0 && !cursor.isBeforeFirst() && !cursor.isAfterLast()) { + val columnType = cursor.type(column - 1) + TypeMapping.toJdbcType(columnType) + } else { + Types.VARCHAR + } + } catch (e: Exception) { + logger.warn("Failed to determine column type for column {}, defaulting to VARCHAR: {}", column, e.message) + Types.VARCHAR + } + } + + override fun getColumnTypeName(column: Int): String { + validateColumnIndex(column) + return TypeMapping.getJdbcTypeName(getColumnType(column)) + } + + override fun isReadOnly(column: Int): Boolean { + validateColumnIndex(column) + return true + } + + override fun isWritable(column: Int): Boolean { + validateColumnIndex(column) + return false + } + + override fun isDefinitelyWritable(column: Int): Boolean { + validateColumnIndex(column) + return false + } + + override fun getColumnClassName(column: Int): String { + validateColumnIndex(column) + val jdbcType = getColumnType(column) + return TypeMapping.getJavaClassName(jdbcType) + } + + override fun unwrap(iface: Class): T = if (iface.isAssignableFrom(this::class.java)) { + @Suppress("UNCHECKED_CAST") + this as T + } else if (iface.isAssignableFrom(ICursor::class.java)) { + @Suppress("UNCHECKED_CAST") + cursor as T + } else { + throw SQLException("Cannot unwrap to ${iface.name}") + } + + override fun isWrapperFor(iface: Class<*>): Boolean = iface.isAssignableFrom(this::class.java) || + iface.isAssignableFrom(ICursor::class.java) + + private fun validateColumnIndex(column: Int) { + if (column !in 1..cursor.columnCount) { + throw SQLException("Column index $column is out of range (1, ${cursor.columnCount})") + } + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcParameterMetaData.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcParameterMetaData.kt new file mode 100644 index 0000000000..e1e7ae9952 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcParameterMetaData.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.statement + +import java.sql.ParameterMetaData +import java.sql.SQLException +import java.sql.Types +import javax.annotation.concurrent.NotThreadSafe + +@NotThreadSafe +internal class JdbcParameterMetaData( + private val parameterCount: Int +) : ParameterMetaData { + override fun getParameterCount(): Int = parameterCount + + override fun isNullable(param: Int): Int { + validateParameterIndex(param) + return ParameterMetaData.parameterNullableUnknown + } + + override fun isSigned(param: Int): Boolean { + validateParameterIndex(param) + return false + } + + override fun getPrecision(param: Int): Int { + validateParameterIndex(param) + return 0 + } + + override fun getScale(param: Int): Int { + validateParameterIndex(param) + return 0 + } + + override fun getParameterType(param: Int): Int { + validateParameterIndex(param) + return Types.VARCHAR + } + + override fun getParameterTypeName(param: Int): String { + validateParameterIndex(param) + return "VARCHAR" + } + + override fun getParameterClassName(param: Int): String { + validateParameterIndex(param) + return String::class.java.name + } + + override fun getParameterMode(param: Int): Int { + validateParameterIndex(param) + return ParameterMetaData.parameterModeIn + } + + override fun unwrap(iface: Class): T { + if (iface.isAssignableFrom(this::class.java)) { + @Suppress("UNCHECKED_CAST") + return this as T + } + throw SQLException("Cannot unwrap to ${iface.name}") + } + + override fun isWrapperFor(iface: Class<*>): Boolean = iface.isAssignableFrom(this::class.java) + + private fun validateParameterIndex(param: Int) { + if (param !in 1..parameterCount) { + throw SQLException("Parameter index $param is out of range (1, $parameterCount)") + } + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcPreparedStatement.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcPreparedStatement.kt new file mode 100644 index 0000000000..0d0688759a --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcPreparedStatement.kt @@ -0,0 +1,469 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.statement + +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.exception.SQLExceptionMapper +import com.bloomberg.selekt.jdbc.result.JdbcResultSet +import com.bloomberg.selekt.jdbc.util.TypeMapping +import java.io.InputStream +import java.io.Reader +import java.math.BigDecimal +import java.net.URL +import java.sql.BatchUpdateException +import java.sql.Blob +import java.sql.Clob +import java.sql.Date +import java.sql.NClob +import java.sql.ParameterMetaData +import java.sql.PreparedStatement +import java.sql.Ref +import java.sql.ResultSet +import java.sql.ResultSetMetaData +import java.sql.RowId +import java.sql.SQLException +import java.sql.SQLFeatureNotSupportedException +import java.sql.SQLXML +import java.sql.Statement +import java.sql.Time +import java.sql.Timestamp +import java.util.Calendar +import javax.annotation.concurrent.NotThreadSafe + +private const val INITIAL_BATCH_CHUNK_SIZE = 128 + +@Suppress("TooGenericExceptionCaught") +@NotThreadSafe +internal open class JdbcPreparedStatement( + connection: JdbcConnection, + private val database: SQLDatabase, + val sql: String, + resultSetType: Int = ResultSet.TYPE_FORWARD_ONLY, + resultSetConcurrency: Int = ResultSet.CONCUR_READ_ONLY, + resultSetHoldability: Int = ResultSet.CLOSE_CURSORS_AT_COMMIT +) : JdbcStatement(connection, database, resultSetType, resultSetConcurrency, resultSetHoldability), PreparedStatement { + @Suppress("Detekt.UseDataClass") + private class BatchChunk(val capacity: Int) { + val data = arrayOfNulls>(capacity) + var count = 0 + var next: BatchChunk? = null + } + + private val parameterCount = sql.count { it == '?' } + private val parameters = arrayOfNulls(parameterCount) + private var firstChunk: BatchChunk? = null + private var currentChunk: BatchChunk? = null + private var totalBatchCount = 0 + + private fun validateParameterIndex(parameterIndex: Int) { + checkClosed() + if (parameterIndex !in 1..parameterCount) { + throw SQLException("Parameter index $parameterIndex is out of range (1, $parameterCount)") + } + } + + override fun executeQuery(): ResultSet { + checkClosed() + return runCatching { + JdbcResultSet( + database.query(sql, buildBindArgs()), + this, + resultSetType, + resultSetConcurrency, + resultSetHoldability + ) + }.getOrElse { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun executeUpdate(): Int { + checkClosed() + return runCatching { + database.compileStatement(sql, buildBindArgs()).executeUpdateDelete() + }.getOrElse { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun execute(): Boolean { + checkClosed() + return runCatching { + if (sql.trim().uppercase().startsWith("SELECT")) { + executeQuery() + true + } else { + executeUpdate() + false + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun clearParameters() { + checkClosed() + parameters.fill(null) + } + + override fun addBatch() { + checkClosed() + if (firstChunk == null) { + firstChunk = BatchChunk(INITIAL_BATCH_CHUNK_SIZE) + currentChunk = firstChunk + } + currentChunk!!.run { + if (count == capacity) { + currentChunk = BatchChunk(totalBatchCount).also { + next = it + } + } + } + currentChunk!!.run { + if (data[count] == null) { + data[count] = Array(parameterCount) { i -> + TypeMapping.convertToSQLite(parameters[i]) + } + } + ++count + } + totalBatchCount++ + } + + override fun clearBatch() { + checkClosed() + currentChunk?.let { + for (i in 0 until it.count) { + it.data[i]?.fill(null) + } + it.count = 0 + it.next = null + firstChunk = it + } + totalBatchCount = 0 + } + + override fun executeBatch(): IntArray { + checkClosed() + if (totalBatchCount == 0) { + return IntArray(0) + } + try { + if (sql.trim().uppercase().startsWith("SELECT")) { + throw SQLException("SELECT statements are not allowed in batch execution") + } + database.batch(sql, sequence { + var chunk: BatchChunk? = firstChunk + while (chunk != null) { + for (i in 0 until chunk.count) { + chunk.data[i]!!.let { data -> + yield(data) + data.fill(null) + } + } + chunk = chunk.next + } + }) + return IntArray(totalBatchCount) { Statement.SUCCESS_NO_INFO } + } catch (e: Exception) { + SQLExceptionMapper.mapException( + e as? SQLException ?: SQLException(e.message, e) + ).run { + throw BatchUpdateException( + message ?: "Batch execution failed", + sqlState, + errorCode, + IntArray(totalBatchCount) { Statement.EXECUTE_FAILED }, + this + ) + } + } finally { + clearBatch() + } + } + + private fun buildBindArgs(): Array = Array(parameterCount) { i -> + TypeMapping.convertToSQLite(parameters[i]) + } + + override fun setNull(parameterIndex: Int, sqlType: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = null + } + + override fun setNull(parameterIndex: Int, sqlType: Int, typeName: String?) { + setNull(parameterIndex, sqlType) + } + + override fun setBoolean(parameterIndex: Int, x: Boolean) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setByte(parameterIndex: Int, x: Byte) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setShort(parameterIndex: Int, x: Short) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setInt(parameterIndex: Int, x: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setLong(parameterIndex: Int, x: Long) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setFloat(parameterIndex: Int, x: Float) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setDouble(parameterIndex: Int, x: Double) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setBigDecimal(parameterIndex: Int, x: BigDecimal?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setString(parameterIndex: Int, x: String?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setBytes(parameterIndex: Int, x: ByteArray?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setDate(parameterIndex: Int, x: Date?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.toString() + } + + override fun setDate(parameterIndex: Int, x: Date?, cal: Calendar?) { + setDate(parameterIndex, x) + } + + override fun setTime(parameterIndex: Int, x: Time?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.toString() + } + + override fun setTime(parameterIndex: Int, x: Time?, cal: Calendar?) { + setTime(parameterIndex, x) + } + + override fun setTimestamp(parameterIndex: Int, x: Timestamp?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.toString() + } + + override fun setTimestamp(parameterIndex: Int, x: Timestamp?, cal: Calendar?) { + setTimestamp(parameterIndex, x) + } + + override fun setObject(parameterIndex: Int, x: Any?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setObject(parameterIndex: Int, x: Any?, targetSqlType: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setObject(parameterIndex: Int, x: Any?, targetSqlType: Int, scaleOrLength: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x + } + + override fun setAsciiStream(parameterIndex: Int, x: InputStream?, length: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.readBytes()?.toString(Charsets.US_ASCII) + } + + override fun setAsciiStream(parameterIndex: Int, x: InputStream?, length: Long) { + setAsciiStream(parameterIndex, x, length.toInt()) + } + + override fun setAsciiStream(parameterIndex: Int, x: InputStream?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.readBytes()?.toString(Charsets.US_ASCII) + } + + @Deprecated("Deprecated in Java", ReplaceWith("setCharacterStream(parameterIndex, x, length)")) + override fun setUnicodeStream(parameterIndex: Int, x: InputStream?, length: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.readBytes()?.toString(Charsets.UTF_16) + } + + override fun setBinaryStream(parameterIndex: Int, x: InputStream?, length: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.readBytes() + } + + override fun setBinaryStream(parameterIndex: Int, x: InputStream?, length: Long) { + setBinaryStream(parameterIndex, x, length.toInt()) + } + + override fun setBinaryStream(parameterIndex: Int, x: InputStream?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = x?.readBytes() + } + + override fun setCharacterStream(parameterIndex: Int, reader: Reader?, length: Int) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = reader?.readText() + } + + override fun setCharacterStream(parameterIndex: Int, reader: Reader?, length: Long) { + setCharacterStream(parameterIndex, reader, length.toInt()) + } + + override fun setCharacterStream(parameterIndex: Int, reader: Reader?) { + validateParameterIndex(parameterIndex) + parameters[parameterIndex - 1] = reader?.readText() + } + + override fun setRef(parameterIndex: Int, x: Ref?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setBlob(parameterIndex: Int, x: Blob?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setBlob(parameterIndex: Int, inputStream: InputStream?, length: Long) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setBlob(parameterIndex: Int, inputStream: InputStream?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setClob(parameterIndex: Int, x: Clob?) { + validateParameterIndex(parameterIndex) + if (x == null) { + parameters[parameterIndex - 1] = null + } else { + val content = x.getSubString(1, x.length().toInt()) + parameters[parameterIndex - 1] = content + } + } + + override fun setClob(parameterIndex: Int, reader: Reader?, length: Long) { + validateParameterIndex(parameterIndex) + if (reader == null) { + parameters[parameterIndex - 1] = null + } else { + val content = CharArray(length.toInt()) + var totalRead = 0 + while (totalRead < length) { + val read = reader.read(content, totalRead, (length - totalRead).toInt()) + if (read == -1) { + break + } + totalRead += read + } + parameters[parameterIndex - 1] = String(content, 0, totalRead) + } + } + + override fun setClob(parameterIndex: Int, reader: Reader?) { + validateParameterIndex(parameterIndex) + if (reader == null) { + parameters[parameterIndex - 1] = null + } else { + parameters[parameterIndex - 1] = reader.readText() + } + } + + override fun setArray(parameterIndex: Int, x: java.sql.Array?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setURL(parameterIndex: Int, x: URL?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setRowId(parameterIndex: Int, x: RowId?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setNString(parameterIndex: Int, value: String?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setNCharacterStream( + parameterIndex: Int, + value: Reader?, + length: Long + ) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setNCharacterStream(parameterIndex: Int, value: Reader?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setNClob(parameterIndex: Int, value: NClob?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setNClob(parameterIndex: Int, reader: Reader?, length: Long) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setNClob(parameterIndex: Int, reader: Reader?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun setSQLXML(parameterIndex: Int, xmlObject: SQLXML?) { + validateParameterIndex(parameterIndex) + throw SQLFeatureNotSupportedException() + } + + override fun getMetaData(): ResultSetMetaData { + throw SQLFeatureNotSupportedException("Metadata not available for prepared statements") + } + + override fun getParameterMetaData(): ParameterMetaData = JdbcParameterMetaData(parameterCount) +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcStatement.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcStatement.kt new file mode 100644 index 0000000000..f5988bb9f6 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcStatement.kt @@ -0,0 +1,293 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.statement + +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.exception.SQLExceptionMapper +import com.bloomberg.selekt.jdbc.result.JdbcResultSet +import java.sql.BatchUpdateException +import java.sql.Connection +import java.sql.ResultSet +import java.sql.SQLException +import java.sql.SQLFeatureNotSupportedException +import java.sql.SQLWarning +import java.sql.Statement +import java.lang.invoke.MethodHandles +import java.lang.invoke.VarHandle +import javax.annotation.concurrent.NotThreadSafe + +@NotThreadSafe +@Suppress("TooGenericExceptionCaught") +open class JdbcStatement internal constructor( + private val connection: JdbcConnection, + private val database: SQLDatabase, + private val resultSetType: Int = ResultSet.TYPE_FORWARD_ONLY, + private val resultSetConcurrency: Int = ResultSet.CONCUR_READ_ONLY, + private val resultSetHoldability: Int = ResultSet.CLOSE_CURSORS_AT_COMMIT +) : Statement { + companion object { + private val CLOSED: VarHandle = MethodHandles.lookup() + .findVarHandle(JdbcStatement::class.java, "closed", Boolean::class.javaPrimitiveType) + } + + @Volatile + private var closed = false + private var currentResultSet: ResultSet? = null + private var updateCount = -1 + private var fetchSize = 0 + private var maxRows = 0 + private var queryTimeout = 0 + private var maxFieldSize = 0 + private var poolable = false + private var closeOnCompletion = false + private val batchedSqlStatements = mutableListOf() + + var escapeProcessing: Boolean = true + private set + + override fun executeQuery(sql: String): ResultSet { + checkClosed() + try { + connection.ensureTransaction() + val cursor = database.query(sql, emptyArray()) + currentResultSet = JdbcResultSet(cursor, this, resultSetType, resultSetConcurrency, resultSetHoldability) + updateCount = -1 + return currentResultSet!! + } catch (e: Exception) { + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun executeUpdate(sql: String): Int { + checkClosed() + return runCatching { + val statement = database.compileStatement(sql) + updateCount = statement.executeUpdateDelete() + currentResultSet = null + updateCount + }.getOrElse { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + @Suppress("Detekt.ComplexCondition") + override fun execute(sql: String): Boolean { + checkClosed() + return sql.trim().uppercase().runCatching { + if (startsWith("SELECT") || + startsWith("WITH") || + startsWith("PRAGMA") || + startsWith("EXPLAIN")) { + executeQuery(sql) + true + } else { + executeUpdate(sql) + false + } + }.getOrElse { e -> + throw SQLExceptionMapper.mapException(e as? SQLException ?: SQLException(e.message, e)) + } + } + + override fun close() { + if (CLOSED.compareAndSet(this, false, true)) { + currentResultSet?.close() + currentResultSet = null + } + } + + override fun isClosed(): Boolean = closed + + override fun getResultSet(): ResultSet? = currentResultSet + + override fun getUpdateCount(): Int = updateCount + + override fun getMoreResults(): Boolean { + currentResultSet?.close() + currentResultSet = null + return false + } + + override fun getMoreResults(current: Int): Boolean = getMoreResults() + + override fun getConnection(): Connection = connection + + override fun getWarnings(): SQLWarning? = null + + override fun clearWarnings() {} + + override fun setCursorName(name: String?) { + throw SQLFeatureNotSupportedException("Named cursors not supported") + } + + override fun setEscapeProcessing(enable: Boolean) { + escapeProcessing = enable + } + + override fun setQueryTimeout(seconds: Int) { + if (seconds < 0) { + throw SQLException("Query timeout must be non-negative") + } + queryTimeout = seconds + } + + override fun getQueryTimeout(): Int = queryTimeout + + override fun cancel() {} + + override fun setFetchDirection(direction: Int) { + if (direction != ResultSet.FETCH_FORWARD) { + throw SQLFeatureNotSupportedException("Only FETCH_FORWARD is supported") + } + } + + override fun getFetchDirection(): Int = ResultSet.FETCH_FORWARD + + override fun setFetchSize(rows: Int) { + if (rows < 0) { + throw SQLException("Fetch size must be non-negative") + } + fetchSize = rows + } + + override fun getFetchSize(): Int = fetchSize + + override fun setMaxRows(max: Int) { + if (max < 0) { + throw SQLException("Max rows must be non-negative") + } + maxRows = max + } + + override fun getMaxRows(): Int = maxRows + + override fun setMaxFieldSize(max: Int) { + if (max < 0) { + throw SQLException("Max field size must be non-negative") + } + checkClosed() + maxFieldSize = max + } + + override fun getMaxFieldSize(): Int = maxFieldSize + + override fun getResultSetConcurrency(): Int = resultSetConcurrency + + override fun getResultSetType(): Int = resultSetType + + override fun getResultSetHoldability(): Int = resultSetHoldability + + override fun addBatch(sql: String) { + checkClosed() + if (sql.isBlank()) { + throw SQLException("SQL statement cannot be empty") + } + batchedSqlStatements.add(sql) + } + + override fun clearBatch() { + checkClosed() + batchedSqlStatements.clear() + } + + override fun executeBatch(): IntArray { + checkClosed() + return if (batchedSqlStatements.isEmpty()) { + IntArray(0) + } else { + try { + executeBatchStatements() + } finally { + clearBatch() + } + } + } + + private fun executeBatchStatements(): IntArray { + val results = mutableListOf() + for (sql in batchedSqlStatements) { + runCatching { + validateBatchSql(sql) + val updateCount = executeUpdate(sql) + results.add(updateCount) + }.onFailure { e -> + throw BatchUpdateException( + e.message ?: "Batch execution failed", + (e as SQLException).sqlState, + e.errorCode, + results.toIntArray(), + e + ) + } + } + return results.toIntArray() + } + + private fun validateBatchSql(sql: String) { + if (sql.trim().uppercase().startsWith("SELECT")) { + throw SQLException("SELECT statements are not allowed in batch execution") + } + } + + override fun setPoolable(poolable: Boolean) { + checkClosed() + this.poolable = poolable + } + + override fun isPoolable(): Boolean = poolable + + override fun closeOnCompletion() { + checkClosed() + closeOnCompletion = true + } + + override fun isCloseOnCompletion(): Boolean = closeOnCompletion + + override fun executeUpdate(sql: String, autoGeneratedKeys: Int): Int = executeUpdate(sql) + + override fun executeUpdate(sql: String, columnIndexes: IntArray): Int = executeUpdate(sql) + + override fun executeUpdate(sql: String, columnNames: Array): Int = executeUpdate(sql) + + override fun execute(sql: String, autoGeneratedKeys: Int): Boolean = execute(sql) + + override fun execute(sql: String, columnIndexes: IntArray): Boolean = execute(sql) + + override fun execute(sql: String, columnNames: Array): Boolean = execute(sql) + + override fun getGeneratedKeys(): ResultSet = throw SQLFeatureNotSupportedException("Generated keys not supported") + + override fun unwrap(iface: Class): T = if (iface.isAssignableFrom(this::class.java)) { + @Suppress("UNCHECKED_CAST") + this as T + } else if (iface.isAssignableFrom(SQLDatabase::class.java)) { + @Suppress("UNCHECKED_CAST") + return database as T + } else { + throw SQLException("Cannot unwrap to ${iface.name}") + } + + override fun isWrapperFor(iface: Class<*>): Boolean = iface.isAssignableFrom(this::class.java) || + iface.isAssignableFrom(SQLDatabase::class.java) + + protected fun checkClosed() { + if (closed) { + throw SQLException("Statement is closed") + } + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/util/ConnectionURL.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/util/ConnectionURL.kt new file mode 100644 index 0000000000..9a03f030eb --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/util/ConnectionURL.kt @@ -0,0 +1,128 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.util + +import java.sql.SQLException +import java.util.Properties + +/** + * Supported format: jdbc:selekt:path/to/database.sqlite[?property=value&...] + * + * Supported properties: + * - encrypt: Enable SQLCipher encryption (true/false) + * - key: Encryption key (hex string or file path) + * - poolSize: Maximum connection pool size (integer) + * - busyTimeout: SQLite busy timeout in milliseconds (integer) + * - journalMode: SQLite journal mode (DELETE, WAL, MEMORY, etc.) + * - foreignKeys: Enable foreign key constraints (true/false) + */ +@Suppress("TooGenericExceptionCaught") +internal class ConnectionURL private constructor( + val databasePath: String, + val properties: Properties +) { + companion object { + private const val JDBC_PREFIX = "jdbc:" + private const val SELEKT_SUBPROTOCOL = "selekt:" + private const val FULL_PREFIX = "$JDBC_PREFIX$SELEKT_SUBPROTOCOL" + + @JvmStatic + fun parse(url: String): ConnectionURL { + if (!url.startsWith(FULL_PREFIX)) { + throw SQLException( + "Invalid JDBC URL format. Expected format: jdbc:selekt:path/to/database.sqlite[?properties...]" + ) + } + return try { + val (databasePath, properties) = parsePathAndProperties(url.substring(FULL_PREFIX.length)) + ConnectionURL(databasePath, properties) + } catch (e: Exception) { + throw SQLException("Failed to parse JDBC URL: $url", e) + } + } + + @JvmStatic + fun isValidUrl(url: String?): Boolean { + if (url == null || !url.startsWith(FULL_PREFIX)) { + return false + } + val pathPart = url.substring(FULL_PREFIX.length) + val questionMarkIndex = pathPart.indexOf('?') + val databasePath = if (questionMarkIndex == -1) { + pathPart + } else { + pathPart.substring(0, questionMarkIndex) + } + return databasePath.isNotBlank() + } + + private fun parsePathAndProperties(urlPart: String): Pair { + val questionMarkIndex = urlPart.indexOf('?') + val databasePath = if (questionMarkIndex == -1) { + urlPart + } else { + urlPart.substring(0, questionMarkIndex) + } + require(databasePath.isNotBlank()) { "Database path cannot be empty" } + val properties = Properties() + if (questionMarkIndex != -1 && questionMarkIndex < urlPart.length - 1) { + val queryString = urlPart.substring(questionMarkIndex + 1) + parseQueryString(queryString, properties) + } + return databasePath to properties + } + + private fun parseQueryString(queryString: String, properties: Properties) { + queryString.split('&').forEach { param -> + val equalIndex = param.indexOf('=') + if (equalIndex != -1) { + val key = param.substring(0, equalIndex).trim() + val value = param.substring(equalIndex + 1).trim() + if (key.isNotEmpty()) { + val decodedValue = java.net.URLDecoder.decode(value, "UTF-8") + properties.setProperty(key, decodedValue) + } + } + } + } + } + + fun getProperty(key: String): String? = properties.getProperty(key) + + fun getProperty(key: String, defaultValue: String): String = properties.getProperty(key, defaultValue) + + fun getBooleanProperty(key: String, defaultValue: Boolean = false): Boolean { + val value = properties.getProperty(key) ?: return defaultValue + return value.equals("true", ignoreCase = true) || value == "1" + } + + fun getIntProperty(key: String, defaultValue: Int = 0): Int { + val value = properties.getProperty(key) ?: return defaultValue + return try { + value.toInt() + } catch (e: NumberFormatException) { + defaultValue + } + } + + override fun toString(): String = "$FULL_PREFIX$databasePath" + + "${if (properties.isNotEmpty()) { "?" } else { "" } }${propertiesToQueryString()}" + + private fun propertiesToQueryString(): String = properties.entries.joinToString("&") { (key, value) -> + "$key=${java.net.URLEncoder.encode(value.toString(), "UTF-8")}" + } +} diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/util/TypeMapping.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/util/TypeMapping.kt new file mode 100644 index 0000000000..d7327150e0 --- /dev/null +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/util/TypeMapping.kt @@ -0,0 +1,326 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.util + +import com.bloomberg.selekt.ColumnType +import java.math.BigDecimal +import java.sql.Date +import java.sql.Time +import java.sql.Timestamp +import java.sql.Types +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import java.time.format.DateTimeFormatter +import java.time.format.DateTimeParseException +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +@Suppress("TooGenericExceptionCaught") +internal object TypeMapping { + private val logger: Logger = LoggerFactory.getLogger(TypeMapping::class.java) + + private const val BOOLEAN_PRECISION = 1 + private const val TINYINT_PRECISION = 3 + private const val SMALLINT_PRECISION = 5 + private const val INTEGER_PRECISION = 10 + private const val BIGINT_PRECISION = 19 + private const val FLOAT_PRECISION = 7 + private const val DOUBLE_PRECISION = 15 + private const val DATE_PRECISION = 10 + private const val TIME_PRECISION = 8 + private const val TIMESTAMP_PRECISION = 23 + + private const val FLOAT_SCALE = 7 + private const val DOUBLE_SCALE = 15 + + fun toJdbcType(columnType: ColumnType): Int = when (columnType) { + ColumnType.INTEGER -> Types.BIGINT + ColumnType.FLOAT -> Types.DOUBLE + ColumnType.STRING -> Types.VARCHAR + ColumnType.BLOB -> Types.VARBINARY + ColumnType.NULL -> Types.NULL + } + + fun toSelektType(jdbcType: Int): ColumnType = when (jdbcType) { + Types.BOOLEAN, + Types.TINYINT, + Types.SMALLINT, + Types.INTEGER, + Types.BIGINT -> ColumnType.INTEGER + Types.REAL, + Types.FLOAT, + Types.DOUBLE, + Types.NUMERIC, + Types.DECIMAL -> ColumnType.FLOAT + Types.CHAR, + Types.VARCHAR, + Types.LONGVARCHAR, + Types.NCHAR, + Types.NVARCHAR, + Types.LONGNVARCHAR, + Types.CLOB, + Types.NCLOB, + Types.DATE, + Types.TIME, + Types.TIMESTAMP, + Types.TIMESTAMP_WITH_TIMEZONE -> ColumnType.STRING + Types.BINARY, + Types.VARBINARY, + Types.LONGVARBINARY, + Types.BLOB -> ColumnType.BLOB + Types.NULL -> ColumnType.NULL + else -> ColumnType.STRING + } + + fun getJavaClassName(jdbcType: Int): String = when (jdbcType) { + Types.BOOLEAN -> Boolean::class.java.name + Types.TINYINT -> Byte::class.java.name + Types.SMALLINT -> Short::class.java.name + Types.INTEGER -> Int::class.java.name + Types.BIGINT -> Long::class.java.name + Types.REAL, Types.FLOAT -> Float::class.java.name + Types.DOUBLE -> Double::class.java.name + Types.NUMERIC, Types.DECIMAL -> BigDecimal::class.java.name + Types.CHAR, Types.VARCHAR, Types.LONGVARCHAR, + Types.NCHAR, Types.NVARCHAR, Types.LONGNVARCHAR, + Types.CLOB, Types.NCLOB -> String::class.java.name + Types.DATE -> Date::class.java.name + Types.TIME -> Time::class.java.name + Types.TIMESTAMP, Types.TIMESTAMP_WITH_TIMEZONE -> Timestamp::class.java.name + Types.BINARY, Types.VARBINARY, Types.LONGVARBINARY -> ByteArray::class.java.name + Types.BLOB -> java.sql.Blob::class.java.name + else -> String::class.java.name + } + + fun convertFromSQLite(value: Any?, targetJdbcType: Int): Any? = if (value == null) { + null + } else { + when (targetJdbcType) { + Types.BOOLEAN -> convertToBoolean(value) + Types.TINYINT -> convertToTinyInt(value) + Types.SMALLINT -> convertToSmallInt(value) + Types.INTEGER -> convertToInteger(value) + Types.BIGINT -> convertToBigInt(value) + Types.REAL, Types.FLOAT -> convertToFloat(value) + Types.DOUBLE -> convertToDouble(value) + Types.NUMERIC, Types.DECIMAL -> convertToDecimal(value) + Types.CHAR, Types.VARCHAR, Types.LONGVARCHAR, + Types.NCHAR, Types.NVARCHAR, Types.LONGNVARCHAR, + Types.CLOB, Types.NCLOB -> value.toString() + Types.DATE -> convertToDate(value) + Types.TIME -> convertToTime(value) + Types.TIMESTAMP, Types.TIMESTAMP_WITH_TIMEZONE -> convertToTimestamp(value) + Types.BINARY, Types.VARBINARY, Types.LONGVARBINARY -> convertToBinary(value) + else -> value + } + } + + private fun convertToBoolean(value: Any): Boolean = when (value) { + is Boolean -> value + is Number -> value.toLong() != 0L + is String -> value.equals("true", ignoreCase = true) || value == "1" + else -> false + } + + private fun convertToTinyInt(value: Any): Byte = when (value) { + is Number -> value.toByte() + is String -> value.toByteOrNull() ?: 0.toByte() + else -> 0.toByte() + } + + private fun convertToSmallInt(value: Any): Short = when (value) { + is Number -> value.toShort() + is String -> value.toShortOrNull() ?: 0.toShort() + else -> 0.toShort() + } + + private fun convertToInteger(value: Any): Int = when (value) { + is Number -> value.toInt() + is String -> value.toIntOrNull() ?: 0 + else -> 0 + } + + private fun convertToBigInt(value: Any): Long = when (value) { + is Number -> value.toLong() + is String -> value.toLongOrNull() ?: 0L + else -> 0L + } + + private fun convertToFloat(value: Any): Float = when (value) { + is Number -> value.toFloat() + is String -> value.toFloatOrNull() ?: 0f + else -> 0f + } + + private fun convertToDouble(value: Any): Double = when (value) { + is Number -> value.toDouble() + is String -> value.toDoubleOrNull() ?: 0.0 + else -> 0.0 + } + + private fun convertToDecimal(value: Any): BigDecimal = when (value) { + is Number -> BigDecimal.valueOf(value.toDouble()) + is String -> try { + BigDecimal(value) + } catch (_: NumberFormatException) { + BigDecimal.ZERO + } + else -> BigDecimal.ZERO + } + + private fun convertToDate(value: Any): Date? = when (value) { + is String -> parseDate(value) + is Number -> Date(value.toLong()) + else -> null + } + + private fun convertToTime(value: Any): Time? = when (value) { + is String -> parseTime(value) + is Number -> Time(value.toLong()) + else -> null + } + + private fun convertToTimestamp(value: Any): Timestamp? = when (value) { + is String -> parseTimestamp(value) + is Number -> Timestamp(value.toLong()) + else -> null + } + + private fun convertToBinary(value: Any): ByteArray = when (value) { + is ByteArray -> value + is String -> value.toByteArray(Charsets.UTF_8) + else -> ByteArray(0) + } + + fun convertToSQLite(value: Any?): Any? = when (value) { + null -> null + is Boolean -> if (value) 1L else 0L + is Byte -> value.toLong() + is Short -> value.toLong() + is Int -> value.toLong() + is Long -> value + is Float -> value.toDouble() + is Double -> value + is BigDecimal -> value.toDouble() + is String -> value + is Date -> value.toString() + is Time -> value.toString() + is Timestamp -> value.toString() + is LocalDate -> value.toString() + is LocalTime -> value.toString() + is LocalDateTime -> value.toString() + is ByteArray -> value + else -> value.toString() + } + + fun getJdbcTypeName(jdbcType: Int): String = when (jdbcType) { + Types.BOOLEAN -> "BOOLEAN" + Types.TINYINT -> "TINYINT" + Types.SMALLINT -> "SMALLINT" + Types.INTEGER -> "INTEGER" + Types.BIGINT -> "BIGINT" + Types.REAL -> "REAL" + Types.FLOAT -> "FLOAT" + Types.DOUBLE -> "DOUBLE" + Types.NUMERIC -> "NUMERIC" + Types.DECIMAL -> "DECIMAL" + Types.CHAR -> "CHAR" + Types.VARCHAR -> "VARCHAR" + Types.LONGVARCHAR -> "LONGVARCHAR" + Types.NCHAR -> "NCHAR" + Types.NVARCHAR -> "NVARCHAR" + Types.LONGNVARCHAR -> "LONGNVARCHAR" + Types.DATE -> "DATE" + Types.TIME -> "TIME" + Types.TIMESTAMP -> "TIMESTAMP" + Types.TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_WITH_TIMEZONE" + Types.BINARY -> "BINARY" + Types.VARBINARY -> "VARBINARY" + Types.LONGVARBINARY -> "LONGVARBINARY" + Types.BLOB -> "BLOB" + Types.CLOB -> "CLOB" + Types.NCLOB -> "NCLOB" + Types.NULL -> "NULL" + else -> "OTHER" + } + + fun getPrecision(jdbcType: Int): Int = when (jdbcType) { + Types.BOOLEAN -> BOOLEAN_PRECISION + Types.TINYINT -> TINYINT_PRECISION + Types.SMALLINT -> SMALLINT_PRECISION + Types.INTEGER -> INTEGER_PRECISION + Types.BIGINT -> BIGINT_PRECISION + Types.REAL, Types.FLOAT -> FLOAT_PRECISION + Types.DOUBLE -> DOUBLE_PRECISION + Types.DATE -> DATE_PRECISION + Types.TIME -> TIME_PRECISION + Types.TIMESTAMP -> TIMESTAMP_PRECISION + else -> 0 + } + + fun getScale(jdbcType: Int): Int = when (jdbcType) { + Types.REAL, Types.FLOAT -> FLOAT_SCALE + Types.DOUBLE -> DOUBLE_SCALE + else -> 0 + } + + private fun parseDate(dateString: String): Date? = runCatching { + Date.valueOf(LocalDate.parse(dateString, DateTimeFormatter.ISO_LOCAL_DATE)) + }.getOrElse { + runCatching { + val timestamp = parseTimestamp(dateString) + timestamp?.let { Date(it.time) } + }.getOrElse { + null + } + } + + private fun parseTime(timeString: String): Time? = runCatching { + Time.valueOf(LocalTime.parse(timeString, DateTimeFormatter.ISO_LOCAL_TIME)) + }.getOrElse { + null + } + + private fun parseTimestamp(timestampString: String): Timestamp? = timestampString.runCatching { + when { + contains('T') -> { + val dateTime = LocalDateTime.parse(timestampString, DateTimeFormatter.ISO_LOCAL_DATE_TIME) + Timestamp.valueOf(dateTime) + } + contains(' ') -> Timestamp.valueOf(timestampString) + matches(Regex("\\d+")) -> Timestamp(timestampString.toLong()) + else -> null + } + }.getOrElse { + when (it) { + is DateTimeParseException -> { + logger.debug("Failed to parse timestamp '{}': {}", timestampString, it.message) + null + } + is NumberFormatException -> { + logger.debug("Invalid numeric timestamp '{}': {}", timestampString, it.message) + null + } + is IllegalArgumentException -> { + logger.debug("Invalid timestamp format '{}': {}", timestampString, it.message) + null + } + else -> throw it + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnectionTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnectionTest.kt new file mode 100644 index 0000000000..b39470230c --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnectionTest.kt @@ -0,0 +1,1169 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.connection + +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.statement.JdbcPreparedStatement +import com.bloomberg.selekt.jdbc.statement.JdbcStatement +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.sql.Connection +import java.sql.ResultSet +import java.sql.SQLException +import java.sql.Savepoint +import java.util.Properties +import java.util.concurrent.Executors +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.doThrow +import org.mockito.kotlin.mock +import org.mockito.kotlin.never +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever + +internal class JdbcConnectionTest { + private lateinit var mockDatabase: SQLDatabase + private lateinit var connectionURL: ConnectionURL + private lateinit var properties: Properties + private lateinit var connection: JdbcConnection + + @BeforeEach + fun setUp() { + mockDatabase = mock() + connectionURL = ConnectionURL.parse("jdbc:selekt:/tmp/test.db") + properties = Properties() + connection = JdbcConnection(mockDatabase, connectionURL, properties) + } + + @Test + fun autoCommitDefault() { + assertTrue(connection.autoCommit) + } + + @Test + fun setAutoCommitTrue(): Unit = connection.run { + autoCommit = true + assertTrue(autoCommit) + } + + @Test + fun setAutoCommitFalse(): Unit = connection.run { + autoCommit = false + assertFalse(autoCommit) + } + + @Test + fun commitWithAutoCommitEnabled(): Unit = connection.run { + autoCommit = true + assertFailsWith { + commit() + } + } + + @Test + fun rollbackWithAutoCommitEnabled(): Unit = connection.run { + autoCommit = true + assertFailsWith { + rollback() + } + } + + @Test + fun transactionOperations(): Unit = connection.run { + autoCommit = false + commit() + rollback() + } + + @Test + fun autoCommitSwitchingPattern(): Unit = connection.run { + assertTrue(autoCommit) + autoCommit = false + assertFalse(autoCommit) + autoCommit = true + assertTrue(autoCommit) + autoCommit = false + assertFalse(autoCommit) + autoCommit = true + assertTrue(autoCommit) + } + + @Test + fun autoCommitModeIdempotent(): Unit = connection.run { + assertTrue(autoCommit) + autoCommit = true + assertTrue(autoCommit) + autoCommit = true + assertTrue(autoCommit) + autoCommit = false + assertFalse(autoCommit) + autoCommit = false + assertFalse(autoCommit) + } + + @Test + fun createStatement(): Unit = connection.createStatement().run { + assertNotNull(this) + assertTrue(this is JdbcStatement) + } + + @Test + fun createStatementWithScrollable() { + assertFailsWith { + connection.createStatement( + ResultSet.TYPE_SCROLL_INSENSITIVE, + ResultSet.CONCUR_READ_ONLY + ) + } + } + + @Test + fun prepareStatement() { + connection.prepareStatement("SELECT * FROM test WHERE id = ?").run { + assertNotNull(this) + assertTrue(this is JdbcPreparedStatement) + } + } + + @Test + fun prepareStatementWithScrollable() { + assertFailsWith { + connection.prepareStatement( + "SELECT * FROM test", + ResultSet.TYPE_SCROLL_INSENSITIVE, + ResultSet.CONCUR_READ_ONLY + ) + } + } + + @Test + fun prepareCall() { + assertFailsWith { + connection.prepareCall("{call test()}") + } + } + + @Test + fun nativeSQL() { + "SELECT * FROM test".let { sql -> + assertEquals(sql, connection.nativeSQL(sql)) + } + } + + @Test + fun closure(): Unit = connection.run { + assertFalse(isClosed) + connection.close() + assertTrue(isClosed) + connection.close() + assertTrue(isClosed) + } + + @Test + fun operationsAfterClose(): Unit = connection.run { + close() + assertFailsWith { + createStatement() + } + assertFailsWith { + prepareStatement("SELECT 1") + } + assertFailsWith { + commit() + } + } + + @Test + fun getMetaData(): Unit = connection.metaData.let { metaData -> + assertNotNull(metaData) + assertEquals(connection, metaData.connection) + } + + @Test + fun readOnlyOperations(): Unit = connection.run { + assertFalse(isReadOnly) + isReadOnly = true + assertTrue(isReadOnly) + isReadOnly = false + assertFalse(isReadOnly) + } + + @Test + fun catalogOperations(): Unit = connection.run { + assertNull(catalog) + catalog = "test_catalog" + assertNull(catalog) + } + + @Test + fun transactionIsolation(): Unit = connection.run { + assertEquals(Connection.TRANSACTION_SERIALIZABLE, transactionIsolation) + transactionIsolation = Connection.TRANSACTION_READ_COMMITTED + assertEquals(Connection.TRANSACTION_SERIALIZABLE, transactionIsolation) + } + + @Test + fun warnings(): Unit = connection.run { + assertNull(warnings) + clearWarnings() + } + + @Test + fun clientInfo(): Unit = connection.run { + clientInfo.let { clientInfo -> + assertNotNull(clientInfo) + assertTrue(clientInfo.isEmpty) + } + setClientInfo("ApplicationName", "Test") + assertTrue(clientInfo.isEmpty) + } + + @Test + fun savepointOperations(): Unit = connection.run { + autoCommit = false + val savepoint = setSavepoint() + assertNotNull(savepoint) + val namedSavepoint = setSavepoint("test_savepoint") + assertNotNull(namedSavepoint) + assertEquals("test_savepoint", namedSavepoint.savepointName) + rollback(savepoint) + releaseSavepoint(namedSavepoint) + } + + @Test + fun savepointWithAutoCommit() { + assertFailsWith { + connection.setSavepoint() + } + } + + @Test + fun holdability(): Unit = connection.run { + assertEquals(ResultSet.CLOSE_CURSORS_AT_COMMIT, holdability) + holdability = ResultSet.HOLD_CURSORS_OVER_COMMIT + assertEquals(ResultSet.CLOSE_CURSORS_AT_COMMIT, holdability) + } + + @Test + fun schema(): Unit = connection.run { + assertNull(schema) + schema = "test_schema" + assertNull(schema) + } + + @Test + fun abort(): Unit = connection.run { + val executor = Executors.newSingleThreadExecutor() + try { + abort(executor) + assertTrue(isClosed) + } finally { + executor.shutdown() + } + } + + @Test + fun networkTimeout(): Unit = connection.run { + assertEquals(0, networkTimeout) + val executor = Executors.newSingleThreadExecutor() + try { + setNetworkTimeout(executor, 5_000) + assertEquals(0, connection.networkTimeout) + } finally { + executor.shutdown() + } + } + + @Test + fun wrapperInterface(): Unit = connection.run { + assertTrue(isWrapperFor(JdbcConnection::class.java)) + assertFalse(isWrapperFor(String::class.java)) + assertEquals(this, unwrap(JdbcConnection::class.java)) + assertFailsWith { + unwrap(String::class.java) + } + } + + @Test + fun createBlob() { + assertFailsWith { + connection.createBlob() + } + } + + @Test + fun createClob(): Unit = connection.createClob().run { + assertNotNull(this) + setString(1, "test") + assertEquals("test", getSubString(1, 4)) + } + + @Test + fun createNClob() { + assertFailsWith { + connection.createNClob() + } + } + + @Test + fun createSQLXML() { + assertFailsWith { + connection.createSQLXML() + } + } + + @Test + fun createArrayOf() { + assertFailsWith { + connection.createArrayOf("VARCHAR", arrayOf("test")) + } + } + + @Test + fun createStruct() { + assertFailsWith { + connection.createStruct("TestStruct", arrayOf("value")) + } + } + + @Test + fun isValid(): Unit = connection.run { + assertTrue(isValid(0)) + assertTrue(isValid(5)) + close() + assertFalse(isValid(0)) + } + + @Test + fun closedStateIsThreadSafe(): Unit = connection.run { + (1..10).map { + Thread { + close() + } + }.apply { + forEach { it.start() } + }.apply { + forEach { it.join() } + } + assertTrue(isClosed) + } + + @Test + fun abortCallsClose(): Unit = connection.run { + val executor = Executors.newSingleThreadExecutor() + try { + assertFalse(isClosed) + abort(executor) + assertTrue(isClosed) + } finally { + executor.shutdown() + } + } + + @Test + fun closedStatePreventsFurtherOperations(): Unit = connection.run { + close() + assertTrue(isClosed) + assertFailsWith { + createStatement() + } + assertFailsWith { + prepareStatement("SELECT 1") + } + assertFailsWith { + commit() + } + assertFailsWith { + rollback() + } + assertFailsWith { + setAutoCommit(false) + } + assertFailsWith { + setSavepoint() + } + assertFailsWith { + createClob() + } + } + + @Test + fun multipleCloseCallsAreIdempotent(): Unit = connection.run { + assertFalse(isClosed) + close() + repeat(2) { + assertTrue(isClosed) + close() + } + assertTrue(isClosed) + } + + @Test + fun createStatementWithUpdatableConcurrency() { + assertFailsWith { + connection.createStatement( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_UPDATABLE + ) + } + } + + @Test + fun createStatementWithAllParameters() { + connection.createStatement( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.CLOSE_CURSORS_AT_COMMIT + ).run { + assertNotNull(this) + assertTrue(this is JdbcStatement) + } + } + + @Test + fun createStatementWithInvalidHoldability() { + connection.createStatement( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.HOLD_CURSORS_OVER_COMMIT + ).run { + assertNotNull(this) + } + } + + @Test + fun prepareStatementWithUpdatableConcurrency() { + assertFailsWith { + connection.prepareStatement( + "SELECT * FROM test", + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_UPDATABLE + ) + } + } + + @Test + fun prepareStatementWithAllParameters() { + connection.prepareStatement( + "SELECT * FROM test", + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.CLOSE_CURSORS_AT_COMMIT + ).run { + assertNotNull(this) + assertTrue(this is JdbcPreparedStatement) + } + } + + @Test + fun prepareStatementWithAutoGeneratedKeys() { + connection.prepareStatement( + "INSERT INTO test VALUES (?)", + java.sql.Statement.RETURN_GENERATED_KEYS + ).run { + assertNotNull(this) + assertTrue(this is JdbcPreparedStatement) + } + } + + @Test + fun prepareStatementWithColumnIndexes() { + connection.prepareStatement( + "INSERT INTO test VALUES (?)", + intArrayOf(1, 2) + ).run { + assertNotNull(this) + assertTrue(this is JdbcPreparedStatement) + } + } + + @Test + fun prepareStatementWithColumnNames() { + connection.prepareStatement( + "INSERT INTO test VALUES (?)", + arrayOf("id", "name") + ).run { + assertNotNull(this) + assertTrue(this is JdbcPreparedStatement) + } + } + + @Test + fun prepareCallWithResultSetParameters() { + assertFailsWith { + connection.prepareCall( + "{call test()}", + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY + ) + } + } + + @Test + fun prepareCallWithAllParameters() { + assertFailsWith { + connection.prepareCall( + "{call test()}", + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.CLOSE_CURSORS_AT_COMMIT + ) + } + } + + @Test + fun rollbackWithNullSavepoint(): Unit = connection.run { + autoCommit = false + rollback(null) + } + + @Test + fun getClientInfoWithName(): Unit = connection.run { + assertNull(getClientInfo("ApplicationName")) + } + + @Test + fun setClientInfoWithProperties(): Unit = connection.run { + setClientInfo(Properties().apply { + setProperty("ApplicationName", "Test") + setProperty("ClientUser", "testuser") + }) + assertNotNull(warnings) + } + + @Test + fun getTypeMap(): Unit = connection.run { + assertNotNull(typeMap) + assertTrue(typeMap.isEmpty()) + } + + @Test + fun setTypeMap(): Unit = connection.run { + setTypeMap(mutableMapOf>().apply { + this["CustomType"] = String::class.java + }) + } + + @Test + fun setHoldabilityUnsupportedValue() { + assertFailsWith { + connection.holdability = 999 + } + } + + @Test + fun isValidWithNegativeTimeout() { + assertFalse(connection.isValid(-1)) + } + + @Test + fun setNetworkTimeoutNegative() { + val executor = Executors.newSingleThreadExecutor() + try { + assertFailsWith { + connection.setNetworkTimeout(executor, -1) + } + } finally { + executor.shutdown() + } + } + + @Test + fun unwrapToSQLDatabase(): Unit = connection.run { + assertTrue(isWrapperFor(SQLDatabase::class.java)) + unwrap(SQLDatabase::class.java).let { + assertNotNull(it) + assertEquals(mockDatabase, it) + } + } + + @Test + fun warningsAccumulation(): Unit = connection.run { + setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED) + assertNotNull(warnings) + holdability = ResultSet.HOLD_CURSORS_OVER_COMMIT + assertNotNull(warnings) + setClientInfo("test", "value") + assertNotNull(warnings) + val executor = Executors.newSingleThreadExecutor() + try { + setNetworkTimeout(executor, 5000) + assertNotNull(warnings) + } finally { + executor.shutdown() + } + clearWarnings() + assertNull(warnings) + } + + @Test + fun nativeSQLAfterClose(): Unit = connection.run { + close() + assertFailsWith { + nativeSQL("SELECT 1") + } + } + + @Test + fun prepareCallAfterClose(): Unit = connection.run { + close() + assertFailsWith { + prepareCall("{call test()}") + } + } + + @Test + fun setCatalogAfterClose(): Unit = connection.run { + close() + assertFailsWith { + catalog = "test" + } + } + + @Test + fun setTransactionIsolationAfterClose(): Unit = connection.run { + close() + assertFailsWith { + transactionIsolation = Connection.TRANSACTION_SERIALIZABLE + } + } + + @Test + fun setTypeMapAfterClose(): Unit = connection.run { + close() + assertFailsWith { + setTypeMap(mutableMapOf()) + } + } + + @Test + fun setHoldabilityAfterClose(): Unit = connection.run { + close() + assertFailsWith { + holdability = ResultSet.CLOSE_CURSORS_AT_COMMIT + } + } + + @Test + fun setClientInfoStringAfterClose(): Unit = connection.run { + close() + assertFailsWith { + setClientInfo("test", "value") + } + } + + @Test + fun setClientInfoPropertiesAfterClose(): Unit = connection.run { + close() + assertFailsWith { + setClientInfo(Properties()) + } + } + + @Test + fun getClientInfoAfterClose(): Unit = connection.run { + close() + assertFailsWith { + getClientInfo("test") + } + } + + @Test + fun getClientInfoPropertiesAfterClose(): Unit = connection.run { + close() + assertFailsWith { + clientInfo + } + } + + @Test + fun setSchemaAfterClose(): Unit = connection.run { + close() + assertFailsWith { + schema = "test" + } + } + + @Test + fun setNetworkTimeoutAfterClose(): Unit = connection.run { + close() + val executor = Executors.newSingleThreadExecutor() + try { + assertFailsWith { + setNetworkTimeout(executor, 1000) + } + } finally { + executor.shutdown() + } + } + + @Test + fun setReadOnlyAfterClose(): Unit = connection.run { + close() + assertFailsWith { + isReadOnly = true + } + } + + @Test + fun releaseSavepointAfterClose() { + val savepoint = mock() + connection.run { + close() + assertFailsWith { + releaseSavepoint(savepoint) + } + } + } + + @Test + fun rollbackSavepointWithAutoCommit() { + val savepoint = mock() + connection.run { + autoCommit = true + assertFailsWith { + rollback(savepoint) + } + } + } + + @Test + fun setAutoCommitErrorHandling() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.setTransactionSuccessful()) doThrow RuntimeException("Database error") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + assertFailsWith { + autoCommit = true + } + } + } + + @Test + fun setAutoCommitEndTransactionErrorHandling() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.endTransaction()) doThrow RuntimeException("End transaction failed") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + assertFailsWith { + autoCommit = true + } + } + } + + @Test + fun commitErrorHandling() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.setTransactionSuccessful()) doThrow RuntimeException("Commit failed") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + assertFailsWith { + commit() + } + } + } + + @Test + fun rollbackErrorHandling() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.endTransaction()) doThrow RuntimeException("Rollback failed") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + assertFailsWith { + rollback() + } + } + } + + @Test + fun rollbackSavepointErrorHandling() { + val database = mock { + whenever(it.exec("ROLLBACK TO SAVEPOINT test_sp")) doThrow RuntimeException("Savepoint rollback failed") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + assertFailsWith { + rollback(savepoint) + } + } + } + + @Test + fun setSavepointErrorHandling() { + val database = mock() + whenever(database.exec("SAVEPOINT test_savepoint")) doThrow RuntimeException("Savepoint creation failed") + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + assertFailsWith { + setSavepoint("test_savepoint") + } + } + } + + @Test + fun setSavepointId(): Unit = connection.run { + autoCommit = false + assertEquals(0, connection.setSavepoint().savepointId) + } + + @Test + fun releaseSavepointErrorHandling() { + val database = mock { + whenever(it.exec("RELEASE SAVEPOINT test_sp")) doThrow RuntimeException("Release failed") + } + JdbcConnection(database, connectionURL, properties).run { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + assertFailsWith { + releaseSavepoint(savepoint) + } + } + } + + @Test + fun closeWithTransactionErrorHandling() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.endTransaction()) doThrow RuntimeException("Transaction end failed") + } + JdbcConnection(database, connectionURL, properties).run { + close() + assertTrue(isClosed) + } + } + + @Test + fun applyConnectionPropertiesErrorHandling() { + val database = mock { + whenever(it.exec("PRAGMA foreign_keys = 1")) doThrow RuntimeException("PRAGMA failed") + } + assertFailsWith { + JdbcConnection(database, connectionURL, properties) + } + } + + @Test + fun ensureTransactionCalled() { + val transactionDatabase = mock { + whenever(it.inTransaction) doReturn false + } + JdbcConnection(transactionDatabase, connectionURL, properties).apply { + autoCommit = false + ensureTransaction() + } + verify(transactionDatabase).beginImmediateTransaction() + } + + @Test + fun ensureTransactionNotCalledWhenInTransaction() { + val database = mock { + whenever(it.inTransaction) doReturn true + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + ensureTransaction() + } + verify(database, never()).beginImmediateTransaction() + } + + @Test + fun ensureTransactionNotCalledInAutoCommit() { + val database = mock { + whenever(it.inTransaction) doReturn false + } + JdbcConnection(database, connectionURL, properties).apply { + autoCommit = true + ensureTransaction() + } + verify(database, never()).beginImmediateTransaction() + } + + @Test + fun commitEndTransactionErrorHandling() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.endTransaction()) doThrow RuntimeException("End transaction failed") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + assertFailsWith { + commit() + } + } + } + + @Test + fun foreignKeysDisabled() { + mock().run { + JdbcConnection(this, connectionURL, Properties().apply { + setProperty("foreignKeys", "false") + }) + verify(this).exec("PRAGMA foreign_keys = 0") + } + } + + @Test + fun setAutoCommitWithSQLException() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.setTransactionSuccessful()) doThrow SQLException("SQL error") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + val exception = assertFailsWith { + autoCommit = true + } + assertNotNull(exception.message) + } + } + + @Test + fun commitWithSQLException() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.setTransactionSuccessful()) doThrow SQLException("Commit SQL error") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + val exception = assertFailsWith { + commit() + } + assertNotNull(exception.message) + } + } + + @Test + fun rollbackWithSQLException() { + val database = mock { + whenever(it.inTransaction) doReturn true + whenever(it.endTransaction()) doThrow SQLException("Rollback SQL error") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + val exception = assertFailsWith { + rollback() + } + assertNotNull(exception.message) + } + } + + @Test + fun setSavepointWithSQLException() { + val database = mock { + whenever(it.exec("SAVEPOINT test_savepoint")) doThrow SQLException("Savepoint SQL error") + } + JdbcConnection(database, connectionURL, properties).run { + autoCommit = false + val exception = assertFailsWith { + setSavepoint("test_savepoint") + } + assertNotNull(exception.message) + } + } + + @Test + fun releaseSavepointWithSQLException() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val database = mock { + whenever(it.exec("RELEASE SAVEPOINT test_sp")) doThrow SQLException("Release SQL error") + } + JdbcConnection(database, connectionURL, properties).run { + val exception = assertFailsWith { + releaseSavepoint(savepoint) + } + assertNotNull(exception.message) + } + } + + @Test + fun setTransactionIsolationUnsupported(): Unit = connection.run { + setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED) + warnings.let { + assertNotNull(it) + assertTrue(it.message?.contains("TRANSACTION_SERIALIZABLE") ?: false) + } + } + + @Test + fun setHoldabilityHoldCursorsOverCommit(): Unit = connection.run { + setHoldability(ResultSet.HOLD_CURSORS_OVER_COMMIT) + warnings.let { + assertNotNull(it) + assertTrue(it.message?.contains("HOLD_CURSORS_OVER_COMMIT") ?: false) + } + } + + @Test + fun setHoldabilityInvalid() { + assertFailsWith { + connection.setHoldability(999) + } + } + + @Test + fun setTransactionIsolationSerializable(): Unit = connection.run { + setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE) + assertEquals(Connection.TRANSACTION_SERIALIZABLE, transactionIsolation) + } + + @Test + fun setHoldabilityCloseCursorsAtCommit(): Unit = connection.run { + setHoldability(ResultSet.CLOSE_CURSORS_AT_COMMIT) + assertEquals(ResultSet.CLOSE_CURSORS_AT_COMMIT, holdability) + } + + @Test + fun rollbackSavepointWithSQLException() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val database = mock { + whenever(it.exec("ROLLBACK TO SAVEPOINT test_sp")) doThrow SQLException("Rollback to savepoint failed") + } + JdbcConnection(database, connectionURL, properties).run { + val exception = assertFailsWith { + rollback(savepoint) + } + assertNotNull(exception.message) + } + } + + @Test + fun releaseSavepointWithDirectSQLException() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val database = mock { + whenever(it.exec("RELEASE SAVEPOINT test_sp")) doThrow SQLException("Release failed", "HY000", 999) + } + JdbcConnection(database, connectionURL, properties).run { + val exception = assertFailsWith { + releaseSavepoint(savepoint) + } + assertNotNull(exception.message) + } + } + + @Test + fun rollbackSavepointWithRuntimeException() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val database = mock { + whenever(it.exec("ROLLBACK TO SAVEPOINT test_sp")) doThrow RuntimeException("Rollback runtime error") + } + JdbcConnection(database, connectionURL, properties).run { + val exception = assertFailsWith { + rollback(savepoint) + } + assertNotNull(exception.message) + } + } + + @Test + fun applyConnectionPropertiesWithSQLException() { + val database = mock { + whenever(it.exec("PRAGMA foreign_keys = 1")) doThrow SQLException("PRAGMA failed", "HY000", 100) + } + assertFailsWith { + JdbcConnection(database, connectionURL, properties) + } + } + + @Test + fun rollbackSavepointWithNullMessageException() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val database = mock { + whenever(it.exec("ROLLBACK TO SAVEPOINT test_sp")) doThrow SQLException(null, "HY000", 100) + } + JdbcConnection(database, connectionURL, properties).run { + val exception = assertFailsWith { + rollback(savepoint) + } + assertNotNull(exception) + } + } + + @Test + fun rollbackSavepointWithRuntimeExceptionNullMessage() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val customException = object : RuntimeException(null as String?) {} + val database = mock { + whenever(it.exec("ROLLBACK TO SAVEPOINT test_sp")) doThrow customException + } + JdbcConnection(database, connectionURL, properties).run { + val exception = assertFailsWith { + rollback(savepoint) + } + assertNotNull(exception) + } + } + + @Test + fun rollbackSavepointWithSQLExceptionEmptyMessage() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val database = mock { + whenever(it.exec("ROLLBACK TO SAVEPOINT test_sp")) doThrow SQLException("") + } + JdbcConnection(database, connectionURL, properties).run { + assertFailsWith { + rollback(savepoint) + } + } + } + + @Test + fun rollbackSavepointWithRuntimeExceptionEmptyMessage() { + val savepoint = mock { + whenever(it.savepointName) doReturn "test_sp" + } + val database = mock { + whenever(it.exec("ROLLBACK TO SAVEPOINT test_sp")) doThrow RuntimeException("") + } + JdbcConnection(database, connectionURL, properties).run { + assertFailsWith { + rollback(savepoint) + } + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDataSourceTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDataSourceTest.kt new file mode 100644 index 0000000000..5e9bd32b13 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDataSourceTest.kt @@ -0,0 +1,553 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.driver + +import java.io.File +import java.io.PrintWriter +import java.sql.SQLException +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertSame +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir + +internal class SelektDataSourceTest { + @TempDir + lateinit var tempDir: File + + private lateinit var dataSource: SelektDataSource + + @BeforeEach + fun setUp() { + dataSource = SelektDataSource() + } + + @AfterEach + fun tearDown() { + dataSource.close() + } + + @Test + fun dataSourceConfiguration(): Unit = dataSource.run { + databasePath = "/tmp/test.db" + maxPoolSize = 20 + busyTimeout = 5_000 + journalMode = "WAL" + foreignKeys = true + setEncryption(true, "test-key") + + assertEquals("/tmp/test.db", databasePath) + assertEquals(20, maxPoolSize) + assertEquals(5_000, busyTimeout) + assertEquals("WAL", journalMode) + assertTrue(foreignKeys) + assertTrue(encryptionEnabled) + assertEquals("test-key", encryptionKey) + } + + @Test + fun invalidPoolSize(): Unit = dataSource.run { + assertFailsWith { + maxPoolSize = 0 + } + assertFailsWith { + maxPoolSize = -1 + } + } + + @Test + fun invalidBusyTimeout() { + assertFailsWith { + dataSource.busyTimeout = -1 + } + } + + @Test + fun invalidJournalMode() { + assertFailsWith { + dataSource.journalMode = "INVALID" + } + } + + @Test + fun validJournalModes(): Unit = dataSource.run { + journalMode = "DELETE" + assertEquals("DELETE", journalMode) + journalMode = "wal" + assertEquals("wal", journalMode) + journalMode = "MEMORY" + assertEquals("MEMORY", journalMode) + } + + @Test + fun loginTimeout(): Unit = dataSource.run { + assertEquals(0, loginTimeout) + loginTimeout = 30 + assertEquals(30, loginTimeout) + assertFailsWith { + loginTimeout = -1 + } + } + + @Test + fun logWriter(): Unit = dataSource.run { + val writer = PrintWriter(System.out) + logWriter = writer + assertSame(writer, logWriter) + } + + @Test + fun getConnectionWithoutConfiguration() { + assertFailsWith { + dataSource.getConnection() + } + } + + @Test + fun close(): Unit = dataSource.run { + assertFalse(isClosed()) + close() + assertTrue(isClosed()) + close() + assertTrue(isClosed()) + } + + @Test + fun getConnectionAfterClose(): Unit = dataSource.run { + databasePath = "/tmp/test.db" + close() + assertFailsWith { + getConnection() + } + } + + @Test + fun wrapperInterface(): Unit = dataSource.run { + assertTrue(isWrapperFor(SelektDataSource::class.java)) + assertFalse(isWrapperFor(String::class.java)) + val unwrapped = unwrap(SelektDataSource::class.java) + assertSame(this, unwrapped) + assertFailsWith { + unwrap(String::class.java) + } + } + + @Test + fun parentLogger(): Unit = dataSource.parentLogger.run { + assertNotNull(this) + assertEquals("com.bloomberg.selekt.jdbc.driver.SelektDataSource", name) + } + + @Test + fun urlGeneration(): Unit = dataSource.run { + databasePath = "/path/to/test.db" + maxPoolSize = 5 + busyTimeout = 2_000 + journalMode = "DELETE" + foreignKeys = false + setEncryption(true, "secret123") + + assertEquals("/path/to/test.db", databasePath) + assertEquals(5, maxPoolSize) + assertEquals(2_000, busyTimeout) + assertEquals("DELETE", journalMode) + assertFalse(foreignKeys) + assertTrue(encryptionEnabled) + assertEquals("secret123", encryptionKey) + } + + @Test + fun setEncryptionWithNullKey(): Unit = dataSource.run { + setEncryption(true, null) + assertTrue(encryptionEnabled) + assertEquals(null, encryptionKey) + setEncryption(false, "somekey") + assertFalse(encryptionEnabled) + assertEquals("somekey", encryptionKey) + } + + @Test + fun getConnectionWithUsernamePassword(): Unit = dataSource.run { + assertFailsWith { + getConnection("user", "pass") + } + } + + @Test + fun foreignKeysDefault() { + assertTrue(dataSource.foreignKeys) + } + + @Test + fun encryptionDisabledByDefault() { + assertFalse(dataSource.encryptionEnabled) + } + + @Test + fun encryptionKeyNullByDefault() { + assertEquals(null, dataSource.encryptionKey) + } + + @Test + fun maxPoolSizeDefault() { + assertEquals(10, dataSource.maxPoolSize) + } + + @Test + fun busyTimeoutDefault() { + assertEquals(2_500, dataSource.busyTimeout) + } + + @Test + fun journalModeDefault() { + assertEquals("WAL", dataSource.journalMode) + } + + @Test + fun logWriterNullByDefault() { + assertEquals(null, dataSource.logWriter) + } + + @Test + fun databasePathEmptyByDefault() { + assertEquals("", dataSource.databasePath) + } + + @Test + fun getConnectionAttempt(): Unit = dataSource.run { + databasePath = File(tempDir, "test.db").absolutePath + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithUsernamePasswordAttempt(): Unit = dataSource.run { + databasePath = File(tempDir, "test2.db").absolutePath + assertFailsWith { + getConnection("user", "password") + } + } + + @Test + fun getConnectionWithEncryption(): Unit = dataSource.run { + databasePath = File(tempDir, "encrypted.db").absolutePath + setEncryption(true, "0x0123456789ABCDEF") + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithFileBasedKey(): Unit = dataSource.run { + databasePath = File(tempDir, "encrypted2.db").absolutePath + setEncryption(true, File(tempDir, "keyfile.bin").apply { + writeBytes(byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8)) + }.absolutePath) + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithStringKey(): Unit = dataSource.run { + databasePath = File(tempDir, "encrypted3.db").absolutePath + setEncryption(true, "my-secret-key") + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithEncryptionDisabled(): Unit = dataSource.run { + databasePath = File(tempDir, "plain.db").absolutePath + setEncryption(false, "ignored-key") + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithCustomPoolSize(): Unit = dataSource.run { + databasePath = File(tempDir, "pooled.db").absolutePath + maxPoolSize = 20 + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithCustomBusyTimeout(): Unit = dataSource.run { + databasePath = File(tempDir, "timeout.db").absolutePath + busyTimeout = 5000 + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithDifferentJournalModes(): Unit = dataSource.run { + listOf( + "DELETE", + "TRUNCATE", + "PERSIST", + "MEMORY", + "WAL", + "OFF" + ).forEach { + databasePath = File(tempDir, "journal-$it.db").absolutePath + journalMode = it + assertFailsWith { + getConnection() + } + } + } + + @Test + fun getConnectionWithForeignKeysDisabled(): Unit = dataSource.run { + databasePath = File(tempDir, "nofk.db").absolutePath + foreignKeys = false + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithAllProperties(): Unit = dataSource.run { + databasePath = File(tempDir, "all-props.db").absolutePath + maxPoolSize = 15 + busyTimeout = 3000 + journalMode = "DELETE" + foreignKeys = false + setEncryption(true, "test-key-123") + assertFailsWith { + getConnection() + } + } + + @Test + fun setEncryptionWithKey(): Unit = dataSource.run { + setEncryption(true, "my-key") + assertTrue(encryptionEnabled) + assertEquals("my-key", encryptionKey) + } + + @Test + fun setEncryptionWithoutKey(): Unit = dataSource.run { + setEncryption(true) + assertTrue(encryptionEnabled) + assertNull(encryptionKey) + } + + @Test + fun setEncryptionDisabled() { + dataSource.setEncryption(false, "ignored") + assertFalse(dataSource.encryptionEnabled) + assertEquals("ignored", dataSource.encryptionKey) + } + + @Test + fun logWriterNull(): Unit = dataSource.run { + logWriter = null + assertNull(logWriter) + } + + @Test + fun closeIdempotent(): Unit = dataSource.run { + close() + assertTrue(isClosed()) + close() + assertTrue(isClosed()) + } + + @Test + fun setDatabasePathUpdatesUrl(): Unit = dataSource.run { + databasePath = "/path/to/mydb.db" + assertEquals("/path/to/mydb.db", databasePath) + } + + @Test + fun getConnectionWithHexKeyUppercaseX(): Unit = dataSource.run { + databasePath = File(tempDir, "hex-upper.db").absolutePath + setEncryption(true, "0X123456") + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithHexKeyLowercaseX(): Unit = dataSource.run { + databasePath = File(tempDir, "hex-lower.db").absolutePath + setEncryption(true, "0x123456") + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithNullEncryptionKey(): Unit = dataSource.run { + databasePath = File(tempDir, "null-key.db").absolutePath + setEncryption(true, null) + assertFailsWith { + getConnection() + } + } + + @Test + fun getConnectionWithNonExistentKeyFile(): Unit = dataSource.run { + databasePath = File(tempDir, "nonexistent-key.db").absolutePath + setEncryption(true, "/nonexistent/path/to/keyfile.bin") + assertFailsWith { + getConnection() + } + } + + @Test + fun allJournalModesUppercase() { + listOf( + "DELETE", + "TRUNCATE", + "PERSIST", + "MEMORY", + "WAL", + "OFF" + ).forEach { + SelektDataSource().run { + try { + journalMode = it.uppercase() + assertEquals(it.uppercase(), journalMode) + } finally { + close() + } + } + } + } + + @Test + fun allJournalModesLowercase() { + listOf( + "delete", + "truncate", + "persist", + "memory", + "wal", + "off" + ).forEach { + SelektDataSource().run { + try { + journalMode = it + assertEquals(it, journalMode) + } finally { + close() + } + } + } + } + + @Test + fun allJournalModesMixedCase() { + listOf( + "Delete", + "Truncate", + "Persist", + "Memory", + "Wal", + "Off" + ).forEach { + SelektDataSource().run { + try { + journalMode = it + assertEquals(it, journalMode) + } finally { + close() + } + } + } + } + + @Test + fun validBusyTimeoutZero(): Unit = dataSource.run { + busyTimeout = 0 + assertEquals(0, busyTimeout) + } + + @Test + fun validPoolSizeOne(): Unit = dataSource.run { + maxPoolSize = 1 + assertEquals(1, maxPoolSize) + } + + @Test + fun validPoolSizeLarge(): Unit = dataSource.run { + maxPoolSize = 1_000 + assertEquals(1_000, maxPoolSize) + } + + @Test + fun validBusyTimeoutLarge(): Unit = dataSource.run { + busyTimeout = 60_000 + assertEquals(60_000, busyTimeout) + } + + @Test + fun validLoginTimeoutZero(): Unit = dataSource.run { + loginTimeout = 0 + assertEquals(0, loginTimeout) + } + + @Test + fun validLoginTimeoutPositive(): Unit = dataSource.run { + loginTimeout = 60 + assertEquals(60, loginTimeout) + } + + @Test + fun wrappedType(): Unit = dataSource.run { + assertTrue(isWrapperFor(SelektDataSource::class.java)) + assertTrue(isWrapperFor(javax.sql.DataSource::class.java)) + } + + @Test + fun notWrappedType(): Unit = dataSource.run { + assertFalse(isWrapperFor(String::class.java)) + assertFalse(isWrapperFor(List::class.java)) + } + + @Test + fun unwrapToDataSource(): Unit = dataSource.run { + val unwrapped = unwrap(javax.sql.DataSource::class.java) + assertSame(this, unwrapped) + } + + @Test + fun unwrapToSelektDataSource(): Unit = dataSource.run { + assertSame(this, unwrap(SelektDataSource::class.java)) + } + + @Test + fun unwrapToInvalidType() { + assertFailsWith { + dataSource.unwrap(List::class.java) + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDriverTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDriverTest.kt new file mode 100644 index 0000000000..4a0e2d8c03 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/driver/SelektDriverTest.kt @@ -0,0 +1,373 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.driver + +import java.sql.Connection +import java.sql.DriverManager +import java.sql.DriverPropertyInfo +import java.sql.SQLException +import java.util.Properties +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +internal class SelektDriverTest { + private lateinit var driver: SelektDriver + private val connections = mutableListOf() + + @BeforeEach + fun setUp() { + driver = SelektDriver() + } + + @AfterEach + fun tearDown() { + connections.run { + forEach { + if (!it.isClosed) { + it.close() + } + } + clear() + } + } + + @Test + fun driverRegistration() { + assertTrue(DriverManager.getDrivers().toList().any { it is SelektDriver }) + } + + @Test + fun acceptsValidURLs() { + driver.run { + listOf( + "jdbc:selekt:/path/to/test.db", + "jdbc:selekt:/path/to/test.db?prop=value", + "jdbc:selekt:./relative/path.db", + ).forEach { + assertTrue(acceptsURL(it)) + } + } + } + + @Test + fun rejectsInvalidURLs() { + driver.run { + listOf( + "jdbc:sqlite:/path/to/test.db", + "jdbc:mysql://localhost:3306/test", + "invalid://url", + null + ).forEach { + assertFalse(acceptsURL(it)) + } + } + } + + @Test + fun driverConnects() { + assertFailsWith { + driver.connect("jdbc:selekt:/tmp/test.db", Properties()) + } + } + + @Test + fun connectWithInvalidURL() { + assertEquals(null, driver.connect("jdbc:mysql://localhost:3306/test", Properties())) + } + + @Test + fun getPropertyInfo() { + driver.getPropertyInfo("jdbc:selekt:/tmp/test.db", Properties()).also { + assertNotNull(it) + assertTrue(it.isNotEmpty()) + }.map(DriverPropertyInfo::name).run { + assertTrue(contains("encrypt")) + assertTrue(contains("key")) + assertTrue(contains("poolSize")) + assertTrue(contains("busyTimeout")) + assertTrue(contains("journalMode")) + assertTrue(contains("foreignKeys")) + } + } + + @Test + fun getPropertyInfoWithInvalidURL() { + assertFailsWith { + driver.getPropertyInfo("invalid://url", Properties()) + } + } + + @Test + fun driverVersion(): Unit = driver.run { + assertEquals(4, majorVersion) + assertEquals(3, minorVersion) + } + + @Test + fun jdbcCompliant() { + assertFalse(driver.jdbcCompliant()) + } + + @Test + fun getParentLogger(): Unit = driver.parentLogger.run { + assertNotNull(this) + assertEquals("com.bloomberg.selekt.jdbc.driver.SelektDriver", name) + } + + @Test + fun propertyInfoDetails(): Unit = driver.getPropertyInfo("jdbc:selekt:/tmp/test.db", Properties()).run { + find { it.name == "encrypt" }.let { + assertNotNull(it) + assertEquals("Enable SQLCipher encryption", it.description) + assertFalse(it.required) + assertEquals("false", it.value) + } + find { it.name == "key" }.let { + assertNotNull(it) + assertEquals("Encryption key (hex string or file path)", it.description) + assertFalse(it.required) + } + find { it.name == "poolSize" }.let { + assertNotNull(it) + assertEquals("Maximum connection pool size", it.description) + assertFalse(it.required) + assertEquals("10", it.value) + } + find { it.name == "busyTimeout" }.let { + assertNotNull(it) + assertEquals("SQLite busy timeout in milliseconds", it.description) + assertFalse(it.required) + } + find { it.name == "journalMode" }.let { + assertNotNull(it) + assertEquals("SQLite journal mode", it.description) + assertFalse(it.required) + assertEquals("WAL", it.value) + } + find { it.name == "foreignKeys" }.let { + assertNotNull(it) + assertEquals("Enable foreign key constraints", it.description) + assertFalse(it.required) + assertEquals("true", it.value) + } + } + + @Test + fun connectWithProperties() { + val url = "jdbc:selekt:/tmp/test.db" + val properties = Properties().apply { + setProperty("encrypt", "true") + setProperty("key", "test-key") + setProperty("poolSize", "5") + setProperty("busyTimeout", "2000") + setProperty("journalMode", "DELETE") + setProperty("foreignKeys", "false") + } + assertFailsWith { + driver.connect(url, properties) + } + } + + @Test + fun connectWithURLProperties() { + val url = "jdbc:selekt:/tmp/test.db?encrypt=true&key=test-key&poolSize=5" + val properties = Properties() + assertFailsWith { + driver.connect(url, properties) + } + } + + @Test + fun propertyInfoWithExistingProperties(): Unit = driver.getPropertyInfo( + "jdbc:selekt:/tmp/test.db", + Properties().apply { + setProperty("encrypt", "true") + setProperty("poolSize", "20") + } + ).run { + find { it.name == "encrypt" }.let { + assertNotNull(it) + assertEquals("true", it.value) + } + find { it.name == "poolSize" }.let { + assertNotNull(it) + assertEquals("20", it.value) + } + find { it.name == "journalMode" }.let { + assertNotNull(it) + assertEquals("WAL", it.value) + } + } + + @Test + fun booleanPropertyChoices(): Unit = driver.getPropertyInfo("jdbc:selekt:/tmp/test.db", Properties()).run { + for (propName in listOf("encrypt", "foreignKeys")) { + val property = find { it.name == propName } + assertNotNull(property, "Property $propName should exist") + assertNotNull(property.choices, "Property $propName should have choices") + assertEquals(2, property.choices.size, "Property $propName should have 2 choices") + assertTrue(property.choices.contains("true"), "Property $propName should have 'true' choice") + assertTrue(property.choices.contains("false"), "Property $propName should have 'false' choice") + } + } + + @Test + fun journalModeChoices() { + val propertyInfo = driver.getPropertyInfo("jdbc:selekt:/tmp/test.db", Properties()) + val journalModeProperty = propertyInfo.find { it.name == "journalMode" } + assertNotNull(journalModeProperty) + assertNotNull(journalModeProperty.choices) + val expectedModes = arrayOf("DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF") + assertEquals(expectedModes.size, journalModeProperty.choices.size) + for (mode in expectedModes) { + assertTrue(journalModeProperty.choices.contains(mode), "Should contain journal mode $mode") + } + } + + @Test + fun urlValidationValid() { + listOf( + "jdbc:selekt:/absolute/path/test.db", + "jdbc:selekt:./relative/path/test.db", + "jdbc:selekt:../parent/test.db", + "jdbc:selekt:/path/with spaces/test.db", + "jdbc:selekt:/path/test.db?prop=value", + "jdbc:selekt:/path/test.db?prop1=value1&prop2=value2" + ).forEach { + assertTrue(driver.acceptsURL(it), "Should accept URL: $it") + } + } + + @Test + fun urlValidationInvalid() { + listOf( + "jdbc:sqlite:/path/test.db", + "jdbc:selekt:", + "jdbc:selekt", + "selekt:/path/test.db", + "invalid://url", + "", + null + ).forEach { + assertFalse(driver.acceptsURL(it), "Should reject URL: $it") + } + } + + @Test + fun connectWithHexKey() { + val properties = Properties().apply { + setProperty("encrypt", "true") + setProperty("key", "0x0123456789ABCDEF") + } + assertFailsWith { + driver.connect("jdbc:selekt:/tmp/test.db", properties) + } + } + + @Test + fun connectWithHexKeyUppercasePrefix() { + val properties = Properties().apply { + setProperty("encrypt", "true") + setProperty("key", "0X0123456789ABCDEF") + } + assertFailsWith { + driver.connect("jdbc:selekt:/tmp/test.db", properties) + } + } + + @Test + fun connectWithEncryptFalse() { + val properties = Properties().apply { + setProperty("encrypt", "false") + setProperty("key", "some-key") + } + assertFailsWith { + driver.connect("jdbc:selekt:/tmp/test.db", properties) + } + } + + @Test + fun connectWithEncryptTrueNoKey() { + val url = "jdbc:selekt:/tmp/test.db" + val properties = Properties().apply { + setProperty("encrypt", "true") + } + assertFailsWith { + driver.connect(url, properties) + } + } + + @Test + fun getPropertyInfoWithBusyTimeout() { + val properties = Properties().apply { + setProperty("busyTimeout", "5000") + } + driver.getPropertyInfo("jdbc:selekt:/tmp/test.db", properties).find { + it.name == "busyTimeout" + }.let { + assertNotNull(it) + assertEquals("5000", it.value) + } + } + + @Test + fun propertyInfoWithAllJournalModes() { + listOf( + "DELETE", + "TRUNCATE", + "PERSIST", + "MEMORY", + "WAL", + "OFF" + ).forEach { + val properties = Properties().apply { + setProperty("journalMode", it) + } + driver.getPropertyInfo("jdbc:selekt:/tmp/test.db", properties).find { info -> + info.name == "journalMode" + }.run { + assertNotNull(this) + assertEquals(it, value) + } + } + } + + @Test + fun connectWithValidPoolSizeAndBusyTimeout() { + val properties = Properties().apply { + setProperty("poolSize", "20") + setProperty("busyTimeout", "5000") + } + assertFailsWith { + driver.connect("jdbc:selekt:/tmp/test.db", properties) + } + } + + @Test + fun connectWithNullJournalMode() { + val properties = Properties() + assertFailsWith { + driver.connect("jdbc:selekt:/tmp/test.db", properties) + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/exception/SQLExceptionMapperTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/exception/SQLExceptionMapperTest.kt new file mode 100644 index 0000000000..c9c9a8b757 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/exception/SQLExceptionMapperTest.kt @@ -0,0 +1,519 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.exception + +import com.bloomberg.selekt.SQL_ABORT +import com.bloomberg.selekt.SQL_ABORT_ROLLBACK +import com.bloomberg.selekt.SQL_AUTH +import com.bloomberg.selekt.SQL_BUSY +import com.bloomberg.selekt.SQL_CANT_OPEN +import com.bloomberg.selekt.SQL_CONSTRAINT +import com.bloomberg.selekt.SQL_CORRUPT +import com.bloomberg.selekt.SQL_DONE +import com.bloomberg.selekt.SQL_ERROR +import com.bloomberg.selekt.SQL_FULL +import com.bloomberg.selekt.SQL_IO_ERROR +import com.bloomberg.selekt.SQL_IO_ERROR_ACCESS +import com.bloomberg.selekt.SQL_IO_ERROR_BLOCKED +import com.bloomberg.selekt.SQL_IO_ERROR_CHECK_RESERVED_LOCK +import com.bloomberg.selekt.SQL_IO_ERROR_CLOSE +import com.bloomberg.selekt.SQL_IO_ERROR_CONVPATH +import com.bloomberg.selekt.SQL_IO_ERROR_DELETE +import com.bloomberg.selekt.SQL_IO_ERROR_DELETE_NO_ENT +import com.bloomberg.selekt.SQL_IO_ERROR_DIR_CLOSE +import com.bloomberg.selekt.SQL_IO_ERROR_DIR_FSTAT +import com.bloomberg.selekt.SQL_IO_ERROR_DIR_FSYNC +import com.bloomberg.selekt.SQL_IO_ERROR_FSYNC +import com.bloomberg.selekt.SQL_IO_ERROR_GET_TEMP_PATH +import com.bloomberg.selekt.SQL_IO_ERROR_LOCK +import com.bloomberg.selekt.SQL_IO_ERROR_MMAP +import com.bloomberg.selekt.SQL_IO_ERROR_NOMEM +import com.bloomberg.selekt.SQL_IO_ERROR_RDLOCK +import com.bloomberg.selekt.SQL_IO_ERROR_READ +import com.bloomberg.selekt.SQL_IO_ERROR_SEEK +import com.bloomberg.selekt.SQL_IO_ERROR_SHMLOCK +import com.bloomberg.selekt.SQL_IO_ERROR_SHMMAP +import com.bloomberg.selekt.SQL_IO_ERROR_SHMOPEN +import com.bloomberg.selekt.SQL_IO_ERROR_SHMSIZE +import com.bloomberg.selekt.SQL_IO_ERROR_SHORT_READ +import com.bloomberg.selekt.SQL_IO_ERROR_TRUNCATE +import com.bloomberg.selekt.SQL_IO_ERROR_UNLOCK +import com.bloomberg.selekt.SQL_IO_ERROR_WRITE +import com.bloomberg.selekt.SQL_LOCKED +import com.bloomberg.selekt.SQL_LOCKED_SHARED_CACHE +import com.bloomberg.selekt.SQL_LOCKED_VTAB +import com.bloomberg.selekt.SQL_MISMATCH +import com.bloomberg.selekt.SQL_MISUSE +import com.bloomberg.selekt.SQL_NOMEM +import com.bloomberg.selekt.SQL_NOTICE_RECOVER_ROLLBACK +import com.bloomberg.selekt.SQL_NOTICE_RECOVER_WAL +import com.bloomberg.selekt.SQL_NOT_A_DATABASE +import com.bloomberg.selekt.SQL_NOT_FOUND +import com.bloomberg.selekt.SQL_OK +import com.bloomberg.selekt.SQL_OK_LOAD_PERMANENTLY +import com.bloomberg.selekt.SQL_RANGE +import com.bloomberg.selekt.SQL_READONLY +import com.bloomberg.selekt.SQL_READONLY_CANT_INIT +import com.bloomberg.selekt.SQL_READONLY_CANT_LOCK +import com.bloomberg.selekt.SQL_READONLY_DB_MOVED +import com.bloomberg.selekt.SQL_READONLY_DIRECTORY +import com.bloomberg.selekt.SQL_READONLY_RECOVERY +import com.bloomberg.selekt.SQL_READONLY_ROLLBACK +import com.bloomberg.selekt.SQL_ROW +import com.bloomberg.selekt.SQL_TOO_BIG +import com.bloomberg.selekt.SQL_WARNING_AUTOINDEX +import java.sql.SQLDataException +import java.sql.SQLException +import java.sql.SQLIntegrityConstraintViolationException +import java.sql.SQLNonTransientConnectionException +import java.sql.SQLNonTransientException +import java.sql.SQLTimeoutException +import java.sql.SQLTransactionRollbackException +import java.sql.SQLTransientException +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import org.junit.jupiter.api.Test + +internal class SQLExceptionMapperTest { + @Test + fun constraintViolationMapping(): Unit = SQLExceptionMapper.mapException( + "Constraint violation", + SQL_CONSTRAINT, + -1 + ).run { + assertTrue(this is SQLIntegrityConstraintViolationException) + assertEquals("23000", sqlState) + assertEquals(SQL_CONSTRAINT, errorCode) + message!!.let { message -> + assertTrue(message.contains("Constraint violation")) + assertTrue(message.contains("SQLITE_CONSTRAINT")) + } + } + + @Test + fun dataExceptionMapping(): Unit = SQLExceptionMapper.mapException("Type mismatch", SQL_MISMATCH, -1).run { + assertTrue(this is SQLDataException) + assertEquals("22000", sqlState) + assertEquals(SQL_MISMATCH, errorCode) + } + + @Test + fun connectionExceptionMapping(): Unit = SQLExceptionMapper.mapException( + "Cannot open database", + SQL_CANT_OPEN, + -1 + ).run { + assertTrue(this is SQLNonTransientConnectionException) + assertEquals("08001", sqlState) + assertEquals(SQL_CANT_OPEN, errorCode) + } + + @Test + fun transientExceptionMapping(): Unit = SQLExceptionMapper.mapException("Database busy", SQL_BUSY, -1).run { + assertTrue(this is SQLTransientException) + assertEquals("40001", sqlState) + assertEquals(SQL_BUSY, errorCode) + } + + @Test + fun timeoutExceptionMapping(): Unit = SQLExceptionMapper.mapException( + "I/O blocked", + SQL_BUSY, + SQL_IO_ERROR_BLOCKED + ).run { + assertTrue(this is SQLTimeoutException) + assertEquals("HYT00", sqlState) + assertEquals(SQL_BUSY, errorCode) + } + + @Test + fun transactionRollbackMapping(): Unit = SQLExceptionMapper.mapException("Transaction aborted", SQL_ABORT, -1).run { + assertTrue(this is SQLTransactionRollbackException) + assertEquals("40000", sqlState) + assertEquals(SQL_ABORT, errorCode) + } + + @Test + fun genericExceptionMapping(): Unit = SQLExceptionMapper.mapException("Unknown error", SQL_ERROR, -1).run { + assertTrue(this is SQLNonTransientException) + assertEquals("HY000", sqlState) + assertEquals(SQL_ERROR, errorCode) + } + + @Test + fun exceptionFromSelektSQLException() { + val originalMessage = "Code: $SQL_CONSTRAINT; Extended: -1; Message: Foreign key constraint failed; Context: INSERT" + val selektException = SQLException(originalMessage) + val mappedException = SQLExceptionMapper.mapException(selektException) + assertTrue(mappedException is SQLIntegrityConstraintViolationException) + assertEquals("23000", mappedException.sqlState) + } + + @Test + fun extendedCodeDescriptions(): Unit = SQLExceptionMapper.mapException( + "I/O error", + SQL_IO_ERROR, + SQL_IO_ERROR_READ + ).message!!.run { + assertTrue(contains("SQLITE_IOERR")) + assertTrue(contains("SQLITE_IOERR_READ")) + } + + @Test + fun nullCodeHandling(): Unit = SQLExceptionMapper.mapException("Generic error", -999, -1).run { + assertEquals("HY000", sqlState) + assertEquals(-999, errorCode) + } + + @Test + fun mapAllConstraintCodes() { + SQLExceptionMapper.mapException("Constraint", SQL_CONSTRAINT).run { + assertTrue(this is SQLIntegrityConstraintViolationException) + assertEquals("23000", sqlState) + } + } + + @Test + fun mapAllDataExceptionCodes() { + SQLExceptionMapper.mapException("Mismatch", SQL_MISMATCH).run { + assertTrue(this is SQLDataException) + assertEquals("22000", sqlState) + } + SQLExceptionMapper.mapException("Too big", SQL_TOO_BIG).run { + assertTrue(this is SQLDataException) + assertEquals("22001", sqlState) + } + SQLExceptionMapper.mapException("Range", SQL_RANGE).run { + assertTrue(this is SQLDataException) + assertEquals("22003", sqlState) + } + } + + @Test + fun mapAllConnectionExceptionCodes() { + SQLExceptionMapper.mapException("Can't open", SQL_CANT_OPEN).run { + assertTrue(this is SQLNonTransientConnectionException) + assertEquals("08001", sqlState) + } + SQLExceptionMapper.mapException("Not a database", SQL_NOT_A_DATABASE).run { + assertTrue(this is SQLNonTransientConnectionException) + assertEquals("08007", sqlState) + } + SQLExceptionMapper.mapException("Corrupt", SQL_CORRUPT).run { + assertTrue(this is SQLNonTransientConnectionException) + assertEquals("08007", sqlState) + } + SQLExceptionMapper.mapException("Auth", SQL_AUTH).run { + assertTrue(this is SQLNonTransientConnectionException) + assertEquals("28000", sqlState) + } + } + + @Test + fun mapTransientExceptionCodes() { + SQLExceptionMapper.mapException("Busy", SQL_BUSY).run { + assertTrue(this is SQLTransientException) + assertEquals("40001", sqlState) + } + SQLExceptionMapper.mapException("Locked", SQL_LOCKED).run { + assertTrue(this is SQLTransientException) + assertEquals("40001", sqlState) + } + SQLExceptionMapper.mapException("Locked shared cache", SQL_LOCKED_SHARED_CACHE).run { + assertTrue(this is SQLTransientException) + assertEquals("40001", sqlState) + } + SQLExceptionMapper.mapException("Locked vtab", SQL_LOCKED_VTAB).run { + assertTrue(this is SQLTransientException) + assertEquals("40001", sqlState) + } + } + + @Test + fun mapTransactionRollbackExceptionCodes() { + SQLExceptionMapper.mapException("Abort", SQL_ABORT).run { + assertTrue(this is SQLTransactionRollbackException) + assertEquals("40000", sqlState) + } + SQLExceptionMapper.mapException("Abort rollback", SQL_ABORT_ROLLBACK).run { + assertTrue(this is SQLTransactionRollbackException) + assertEquals("40000", sqlState) + } + } + + @Test + fun mapRecoverableExceptionCodes() { + SQLExceptionMapper.mapException("No memory", SQL_NOMEM).run { + assertTrue(this is java.sql.SQLRecoverableException) + assertEquals("53000", sqlState) + } + SQLExceptionMapper.mapException("I/O error", SQL_IO_ERROR, SQL_IO_ERROR_NOMEM).run { + assertTrue(this is java.sql.SQLRecoverableException) + assertEquals("53000", sqlState) + } + } + + @Test + fun mapIOErrorWithAccessExtendedCode() { + SQLExceptionMapper.mapException("I/O error", SQL_IO_ERROR, SQL_IO_ERROR_ACCESS).run { + assertTrue(this is SQLTransientException) + assertEquals("HY000", sqlState) + } + SQLExceptionMapper.mapException("I/O error", SQL_IO_ERROR, SQL_IO_ERROR_LOCK).run { + assertTrue(this is SQLTransientException) + assertEquals("HY000", sqlState) + } + SQLExceptionMapper.mapException("I/O error", SQL_IO_ERROR, SQL_IO_ERROR_UNLOCK).run { + assertTrue(this is SQLTransientException) + assertEquals("HY000", sqlState) + } + } + + @Test + fun mapIOErrorWithGenericExtendedCode() { + SQLExceptionMapper.mapException("I/O error", SQL_IO_ERROR, SQL_IO_ERROR_READ).run { + assertTrue(this is SQLNonTransientException) + assertEquals("HY000", sqlState) + } + } + + @Test + fun mapIOErrorNoExtendedCode() { + SQLExceptionMapper.mapException("I/O error", SQL_IO_ERROR).run { + assertTrue(this is SQLNonTransientException) + assertEquals("HY000", sqlState) + } + } + + @Test + fun mapNonTransientExceptionCodes() { + SQLExceptionMapper.mapException("Full", SQL_FULL).run { + assertTrue(this is SQLNonTransientException) + assertEquals("53100", sqlState) + } + SQLExceptionMapper.mapException("Read only", SQL_READONLY).run { + assertTrue(this is SQLNonTransientException) + assertEquals("25006", sqlState) + } + SQLExceptionMapper.mapException("Misuse", SQL_MISUSE).run { + assertTrue(this is SQLNonTransientException) + assertEquals("HY010", sqlState) + } + SQLExceptionMapper.mapException("Not found", SQL_NOT_FOUND).run { + assertTrue(this is SQLNonTransientException) + assertEquals("42000", sqlState) + } + SQLExceptionMapper.mapException("Error", SQL_ERROR).run { + assertTrue(this is SQLNonTransientException) + assertEquals("HY000", sqlState) + } + } + + @Test + fun mapSuccessCodes() { + SQLExceptionMapper.mapException("OK", SQL_OK).run { + assertEquals("00000", sqlState) + assertEquals(SQL_OK, errorCode) + } + SQLExceptionMapper.mapException("Row", SQL_ROW).run { + assertEquals("00000", sqlState) + assertEquals(SQL_ROW, errorCode) + } + SQLExceptionMapper.mapException("Done", SQL_DONE).run { + assertEquals("00000", sqlState) + assertEquals(SQL_DONE, errorCode) + } + } + + @Suppress("Detekt.CyclomaticComplexMethod", "Detekt.LongMethod") + @Test + fun extendedCodeDescriptionsComprehensive() { + SQLExceptionMapper.mapException("Abort", SQL_ABORT, SQL_ABORT_ROLLBACK).message!!.run { + assertTrue(contains("SQLITE_ABORT_ROLLBACK")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_CHECK_RESERVED_LOCK).message!!.run { + assertTrue(contains("SQLITE_IOERR_CHECKRESERVEDLOCK")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_CLOSE).message!!.run { + assertTrue(contains("SQLITE_IOERR_CLOSE")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_CONVPATH).message!!.run { + assertTrue(contains("SQLITE_IOERR_CONVPATH")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_DELETE).message!!.run { + assertTrue(contains("SQLITE_IOERR_DELETE")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_DELETE_NO_ENT).message!!.run { + assertTrue(contains("SQLITE_IOERR_DELETE_NOENT")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_DIR_CLOSE).message!!.run { + assertTrue(contains("SQLITE_IOERR_DIR_CLOSE")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_DIR_FSYNC).message!!.run { + assertTrue(contains("SQLITE_IOERR_DIR_FSYNC")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_DIR_FSTAT).message!!.run { + assertTrue(contains("SQLITE_IOERR_DIR_FSTAT")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_FSYNC).message!!.run { + assertTrue(contains("SQLITE_IOERR_FSYNC")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_GET_TEMP_PATH).message!!.run { + assertTrue(contains("SQLITE_IOERR_GETTEMPPATH")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_MMAP).message!!.run { + assertTrue(contains("SQLITE_IOERR_MMAP")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_RDLOCK).message!!.run { + assertTrue(contains("SQLITE_IOERR_RDLOCK")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_SEEK).message!!.run { + assertTrue(contains("SQLITE_IOERR_SEEK")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_SHMLOCK).message!!.run { + assertTrue(contains("SQLITE_IOERR_SHMLOCK")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_SHMMAP).message!!.run { + assertTrue(contains("SQLITE_IOERR_SHMMAP")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_SHMOPEN).message!!.run { + assertTrue(contains("SQLITE_IOERR_SHMOPEN")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_SHMSIZE).message!!.run { + assertTrue(contains("SQLITE_IOERR_SHMSIZE")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_SHORT_READ).message!!.run { + assertTrue(contains("SQLITE_IOERR_SHORT_READ")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_TRUNCATE).message!!.run { + assertTrue(contains("SQLITE_IOERR_TRUNCATE")) + } + SQLExceptionMapper.mapException("I/O", SQL_IO_ERROR, SQL_IO_ERROR_WRITE).message!!.run { + assertTrue(contains("SQLITE_IOERR_WRITE")) + } + } + + @Test + fun readonlyExtendedCodes() { + SQLExceptionMapper.mapException("RO", SQL_READONLY, SQL_READONLY_CANT_INIT).message!!.run { + assertTrue(contains("SQLITE_READONLY_CANTINIT")) + } + SQLExceptionMapper.mapException("RO", SQL_READONLY, SQL_READONLY_CANT_LOCK).message!!.run { + assertTrue(contains("SQLITE_READONLY_CANTLOCK")) + } + SQLExceptionMapper.mapException("RO", SQL_READONLY, SQL_READONLY_DB_MOVED).message!!.run { + assertTrue(contains("SQLITE_READONLY_DBMOVED")) + } + SQLExceptionMapper.mapException("RO", SQL_READONLY, SQL_READONLY_DIRECTORY).message!!.run { + assertTrue(contains("SQLITE_READONLY_DIRECTORY")) + } + SQLExceptionMapper.mapException("RO", SQL_READONLY, SQL_READONLY_RECOVERY).message!!.run { + assertTrue(contains("SQLITE_READONLY_RECOVERY")) + } + SQLExceptionMapper.mapException("RO", SQL_READONLY, SQL_READONLY_ROLLBACK).message!!.run { + assertTrue(contains("SQLITE_READONLY_ROLLBACK")) + } + } + + @Test + fun noticeAndWarningExtendedCodes() { + SQLExceptionMapper.mapException("Notice", SQL_OK, SQL_NOTICE_RECOVER_ROLLBACK).message!!.run { + assertTrue(contains("SQLITE_NOTICE_RECOVER_ROLLBACK")) + } + SQLExceptionMapper.mapException("Notice", SQL_OK, SQL_NOTICE_RECOVER_WAL).message!!.run { + assertTrue(contains("SQLITE_NOTICE_RECOVER_WAL")) + } + SQLExceptionMapper.mapException("Warning", SQL_OK, SQL_WARNING_AUTOINDEX).message!!.run { + assertTrue(contains("SQLITE_WARNING_AUTOINDEX")) + } + SQLExceptionMapper.mapException("OK", SQL_OK, SQL_OK_LOAD_PERMANENTLY).message!!.run { + assertTrue(contains("SQLITE_OK_LOAD_PERMANENTLY")) + } + } + + @Test + fun unknownExtendedCode() { + SQLExceptionMapper.mapException("Unknown", SQL_ERROR, -5678).message!!.run { + assertTrue(contains("SQLITE_UNKNOWN_EXTENDED(-5678)")) + } + } + + @Test + fun unknownPrimaryCode() { + SQLExceptionMapper.mapException("Unknown", -1234).message!!.run { + assertTrue(contains("SQLITE_UNKNOWN(-1234)")) + } + } + + @Test + fun messageWithoutExtendedCode() { + SQLExceptionMapper.mapException("Test error", SQL_ERROR).message!!.run { + listOf( + "Test error", + "SQLITE_ERROR" + ).forEach { + assertTrue(contains(it)) + } + assertFalse(contains(";")) + } + } + + @Test + fun messageWithExtendedCode() { + SQLExceptionMapper.mapException("Test error", SQL_IO_ERROR, SQL_IO_ERROR_READ).message!!.run { + listOf( + "Test error", + "SQLITE_IOERR", + ";", + "SQLITE_IOERR_READ" + ).forEach { + assertTrue(contains(it)) + } + } + } + + @Test + fun exceptionWithCause() { + SQLExceptionMapper.mapException("Error", SQL_ERROR, -1, RuntimeException("Original cause")).run { + assertEquals(cause, this.cause) + } + } + + @Test + fun extractSQLCodeFromMessage() { + val mapped = SQLExceptionMapper.mapException( + SQLException("Code: 19; Extended: -1; Message: UNIQUE constraint failed") + ) + assertTrue(mapped is SQLIntegrityConstraintViolationException) + assertEquals(19, mapped.errorCode) + } + + @Test + fun extractExtendedSQLCodeFromMessage() { + val mapped = SQLExceptionMapper.mapException( + SQLException("Code: $SQL_BUSY; Extended: $SQL_IO_ERROR_BLOCKED; Message: I/O error blocked") + ) + assertTrue(mapped is SQLTimeoutException) + assertTrue(mapped.message!!.contains("SQLITE_IOERR_BLOCKED")) + } + + @Test + fun extractNoCodeFromMessage() { + assertEquals(-1, SQLExceptionMapper.mapException(SQLException("No code in message", "HY000", -1)).errorCode) + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/lob/JdbcClobTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/lob/JdbcClobTest.kt new file mode 100644 index 0000000000..3e3b55f014 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/lob/JdbcClobTest.kt @@ -0,0 +1,354 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.lob + +import java.sql.SQLException +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import org.junit.jupiter.api.Test + +internal class JdbcClobTest { + @Test + fun emptyConstructor() { + assertEquals(0L, JdbcClob().length()) + } + + @Test + fun constructorWithInitialContent() { + val content = "Hello, World!" + val clob = JdbcClob(content) + assertEquals(content.length.toLong(), clob.length()) + assertEquals(content, clob.getSubString(1, content.length)) + } + + @Test + fun getSubString(): Unit = JdbcClob("Hello, World!").run { + assertEquals("Hello, World!", getSubString(1, 13)) + assertEquals("Hello", getSubString(1, 5)) + assertEquals("World", getSubString(8, 5)) + assertEquals("World!", getSubString(8, 20)) + } + + @Test + fun getSubStringInvalidPosition(): Unit = JdbcClob("Hello").run { + assertFailsWith { + getSubString(0, 5) + } + assertFailsWith { + getSubString(-1, 5) + } + assertFailsWith { + getSubString(10, 5) + } + } + + @Test + fun getSubStringNegativeLength() { + assertFailsWith { + JdbcClob("Hello").getSubString(1, -1) + } + } + + @Test + fun getCharacterStream() { + val content = "Hello, World!" + val clob = JdbcClob(content) + val reader = clob.getCharacterStream() + val result = reader.readText() + assertEquals(content, result) + } + + @Test + fun getCharacterStreamWithPosition() { + assertEquals("World", JdbcClob("Hello, World!").getCharacterStream(8, 5).readText()) + } + + @Test + fun getAsciiStream() { + val content = "Hello" + val clob = JdbcClob(content) + val stream = clob.getAsciiStream() + val result = stream.readBytes().toString(Charsets.US_ASCII) + assertEquals(content, result) + } + + @Test + fun positionString(): Unit = JdbcClob("Hello, World! Hello again!").run { + assertEquals(1L, position("Hello", 1)) + assertEquals(15L, position("Hello", 2)) + assertEquals(-1L, position("Goodbye", 1)) + assertEquals(8L, position("World", 1)) + } + + @Test + fun positionStringInvalidStart(): Unit = JdbcClob("Hello").run { + assertFailsWith { + position("Hello", 0) + } + } + + @Test + fun positionStringPastEnd() { + assertEquals(-1L, JdbcClob("Hello").position("Hello", 10)) + } + + @Test + fun positionClob() { + val clob = JdbcClob("Hello, World!") + val searchClob = JdbcClob("World") + assertEquals(8L, clob.position(searchClob, 1)) + } + + @Test + fun setStringSimple() { + val clob = JdbcClob() + val written = clob.setString(1, "Hello") + assertEquals(5, written) + assertEquals("Hello", clob.getSubString(1, 5)) + } + + @Test + fun setStringReplace(): Unit = JdbcClob("Hello, World!").run { + setString(8, "Earth") + assertEquals("Hello, Earth!", getSubString(1, 13)) + } + + @Test + fun setStringExtend(): Unit = JdbcClob("Hello").run { + setString(7, "World") + assertEquals("Hello World", getSubString(1, 11)) + } + + @Test + fun setStringWithOffset(): Unit = JdbcClob().run { + val written = setString(1, "Hello, World!", 7, 5) + assertEquals(5, written) + assertEquals("World", getSubString(1, 5)) + } + + @Test + fun setStringInvalidPosition(): Unit = JdbcClob().run { + assertFailsWith { + setString(0, "Hello") + } + } + + @Test + fun setStringInvalidOffset() { + val clob = JdbcClob() + assertFailsWith { + clob.setString(1, "Hello", -1, 5) + } + assertFailsWith { + clob.setString(1, "Hello", 10, 5) + } + } + + @Test + fun setStringInvalidLength() { + val clob = JdbcClob() + assertFailsWith { + clob.setString(1, "Hello", 0, -1) + } + assertFailsWith { + clob.setString(1, "Hello", 0, 10) + } + } + + @Test + fun setCharacterStream() { + val clob = JdbcClob() + val writer = clob.setCharacterStream(1) + writer.write("Hello, World!") + writer.flush() + assertEquals("Hello, World!", clob.getSubString(1, 13)) + } + + @Test + fun setCharacterStreamAtPosition() { + val clob = JdbcClob("Hello, World!") + val writer = clob.setCharacterStream(8) + writer.write("Earth!") + writer.flush() + assertEquals("Hello, Earth!", clob.getSubString(1, 13)) + } + + @Test + fun setCharacterStreamInvalidPosition() { + val clob = JdbcClob() + assertFailsWith { + clob.setCharacterStream(0) + } + } + + @Test + fun setAsciiStream() { + val clob = JdbcClob() + val stream = clob.setAsciiStream(1) + stream.write("Hello".toByteArray(Charsets.US_ASCII)) + stream.flush() + assertEquals("Hello", clob.getSubString(1, 5)) + } + + @Test + fun setAsciiStreamSingleByte() { + val clob = JdbcClob() + val stream = clob.setAsciiStream(1) + stream.write('H'.code) + stream.write('i'.code) + stream.flush() + assertEquals("Hi", clob.getSubString(1, 2)) + } + + @Test + fun setAsciiStreamReplace() { + val clob = JdbcClob("Hello, World!") + val stream = clob.setAsciiStream(8) + stream.write("Earth".toByteArray(Charsets.US_ASCII)) + stream.flush() + assertEquals("Hello, Earth!", clob.getSubString(1, 13)) + } + + @Test + fun setAsciiStreamInvalidPosition() { + val clob = JdbcClob() + assertFailsWith { + clob.setAsciiStream(0) + } + } + + @Test + fun truncate() { + val clob = JdbcClob("Hello, World!") + clob.truncate(5) + assertEquals(5L, clob.length()) + assertEquals("Hello", clob.getSubString(1, 5)) + } + + @Test + fun truncateToZero() { + val clob = JdbcClob("Hello") + clob.truncate(0) + assertEquals(0L, clob.length()) + } + + @Test + fun truncateBeyondLength() { + val clob = JdbcClob("Hello") + clob.truncate(10) + assertEquals(5L, clob.length()) + assertEquals("Hello", clob.getSubString(1, 5)) + } + + @Test + fun truncateNegative() { + val clob = JdbcClob("Hello") + assertFailsWith { + clob.truncate(-1) + } + } + + @Test + fun free() { + val clob = JdbcClob("Hello") + clob.free() + assertFailsWith { + clob.length() + } + assertFailsWith { + clob.getSubString(1, 5) + } + assertFailsWith { + clob.setString(1, "World") + } + } + + @Test + fun freeIdempotent() { + val clob = JdbcClob("Hello") + clob.free() + clob.free() + } + + @Test + fun asString() { + val content = "Hello, World!" + val clob = JdbcClob(content) + assertEquals(content, clob.asString()) + } + + @Test + fun asStringAfterModification(): Unit = JdbcClob("Hello").run { + setString(7, "World") + assertEquals("Hello World", asString()) + } + + @Test + fun asStringAfterFree(): Unit = JdbcClob("Hello").run { + free() + assertFailsWith { + asString() + } + } + + @Test + fun complexModificationSequence(): Unit = JdbcClob().run { + setString(1, "Hello") + assertEquals("Hello", asString()) + setString(6, ", ") + assertEquals("Hello, ", asString()) + setString(8, "World!") + assertEquals("Hello, World!", asString()) + setString(8, "Earth!") + assertEquals("Hello, Earth!", asString()) + truncate(7) + assertEquals("Hello, ", asString()) + } + + @Test + fun streamWriterInteraction(): Unit = JdbcClob().run { + setAsciiStream(1).run { + write("Hello".toByteArray(Charsets.US_ASCII)) + flush() + } + setCharacterStream(6).run { + write(" World") + flush() + } + assertEquals("Hello World", asString()) + } + + @Test + fun positionAfterModification(): Unit = JdbcClob("Hello, World!").run { + assertEquals(8L, position("World", 1)) + setString(8, "Earth") + assertEquals(-1L, position("World", 1)) + assertEquals(8L, position("Earth", 1)) + } + + @Test + fun setStringWithGap(): Unit = JdbcClob("Hi").apply { + setString(5, "There") + }.asString().run { + assertEquals(9, length) + assertTrue(startsWith("Hi")) + assertTrue(endsWith("There")) + assertEquals(' ', this[2]) + assertEquals(' ', this[3]) + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/metadata/JdbcDatabaseMetaDataTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/metadata/JdbcDatabaseMetaDataTest.kt new file mode 100644 index 0000000000..b1f8724128 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/metadata/JdbcDatabaseMetaDataTest.kt @@ -0,0 +1,1042 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.metadata + +import com.bloomberg.selekt.ColumnType +import com.bloomberg.selekt.ICursor +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.sql.Connection +import java.sql.DatabaseMetaData +import java.sql.ResultSet +import java.sql.RowIdLifetime +import java.sql.SQLException +import java.sql.Types +import java.util.Properties +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertSame +import kotlin.test.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.doAnswer +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever + +internal class JdbcDatabaseMetaDataTest { + private lateinit var mockDatabase: SQLDatabase + private lateinit var mockConnection: JdbcConnection + private lateinit var mockCursor: ICursor + private lateinit var metaData: JdbcDatabaseMetaData + + @BeforeEach + fun setUp() { + mockCursor = mock { + whenever(it.moveToNext()).doReturn(false) + whenever(it.columnCount).doReturn(0) + whenever(it.columnNames()).doReturn(emptyArray()) + whenever(it.isClosed()).doReturn(false) + } + mockDatabase = mock { + whenever(it.query(any(), any>())).doReturn(mockCursor) + } + val connectionURL = ConnectionURL.parse("jdbc:selekt:/tmp/test.db") + mockConnection = JdbcConnection(mockDatabase, connectionURL, Properties()) + metaData = JdbcDatabaseMetaData(mockConnection, mockDatabase, connectionURL) + } + + @Test + fun getConnection() { + assertEquals(mockConnection, metaData.connection) + } + + @Test + fun databaseProductInfo(): Unit = metaData.run { + assertEquals("SQLite", databaseProductName) + assertNotNull(databaseProductVersion) + assertTrue(databaseProductVersion.isNotEmpty()) + } + + @Test + fun driverInfo(): Unit = metaData.run { + assertEquals("Selekt JDBC Driver", driverName) + assertEquals("4.3", driverVersion) + assertEquals(4, driverMajorVersion) + assertEquals(3, driverMinorVersion) + } + + @Test + fun jdbcVersion(): Unit = metaData.run { + assertEquals(4, jdbcMajorVersion) + assertEquals(3, jdbcMinorVersion) + } + + @Test + fun sqlKeywords(): Unit = metaData.sqlKeywords.run { + assertTrue(contains("PRAGMA")) + assertTrue(contains("AUTOINCREMENT")) + } + + @Test + fun identifierInfo(): Unit = metaData.run { + assertEquals("\"", identifierQuoteString) + assertFalse(isCatalogAtStart) + assertEquals(".", catalogSeparator) + assertEquals("", schemaTerm) + assertEquals("", catalogTerm) + assertEquals("table", procedureTerm) + } + + @Test + fun transactionSupport(): Unit = metaData.run { + assertTrue(supportsTransactions()) + assertEquals(Connection.TRANSACTION_SERIALIZABLE, defaultTransactionIsolation) + assertTrue(supportsTransactionIsolationLevel(Connection.TRANSACTION_SERIALIZABLE)) + assertFalse(supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_COMMITTED)) + assertFalse(supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_UNCOMMITTED)) + assertFalse(supportsTransactionIsolationLevel(Connection.TRANSACTION_REPEATABLE_READ)) + } + + @Test + fun sqlSupport(): Unit = metaData.run { + assertTrue(supportsAlterTableWithAddColumn()) + assertTrue(supportsAlterTableWithDropColumn()) + assertFalse(supportsColumnAliasing()) + assertTrue(supportsConvert()) + assertFalse(supportsTableCorrelationNames()) + assertTrue(supportsDifferentTableCorrelationNames()) + assertTrue(supportsExpressionsInOrderBy()) + assertTrue(supportsOrderByUnrelated()) + assertTrue(supportsGroupBy()) + assertTrue(supportsGroupByUnrelated()) + assertTrue(supportsGroupByBeyondSelect()) + assertTrue(supportsLikeEscapeClause()) + assertFalse(supportsMultipleResultSets()) + assertFalse(supportsMultipleTransactions()) + assertTrue(supportsNonNullableColumns()) + assertTrue(supportsMinimumSQLGrammar()) + assertTrue(supportsCoreSQLGrammar()) + assertFalse(supportsExtendedSQLGrammar()) + assertTrue(supportsANSI92EntryLevelSQL()) + assertFalse(supportsANSI92IntermediateSQL()) + assertFalse(supportsANSI92FullSQL()) + } + + @Test + fun resultSetSupport(): Unit = metaData.run { + assertTrue(supportsResultSetType(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(supportsResultSetType(ResultSet.TYPE_SCROLL_INSENSITIVE)) + assertFalse(supportsResultSetType(ResultSet.TYPE_SCROLL_SENSITIVE)) + assertTrue(supportsResultSetConcurrency( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY + )) + assertFalse(supportsResultSetConcurrency( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_UPDATABLE + )) + } + + @Test + fun joinSupport(): Unit = metaData.run { + assertTrue(supportsOuterJoins()) + assertTrue(supportsFullOuterJoins()) + assertTrue(supportsLimitedOuterJoins()) + } + + @Test + fun subquerySupport(): Unit = metaData.run { + assertTrue(supportsSubqueriesInComparisons()) + assertTrue(supportsSubqueriesInExists()) + assertTrue(supportsSubqueriesInIns()) + assertTrue(supportsSubqueriesInQuantifieds()) + assertTrue(supportsCorrelatedSubqueries()) + } + + @Test + fun unionSupport(): Unit = metaData.run { + assertTrue(supportsUnion()) + assertTrue(supportsUnionAll()) + } + + @Test + fun cursorSupport(): Unit = metaData.run { + assertFalse(supportsOpenCursorsAcrossCommit()) + assertFalse(supportsOpenCursorsAcrossRollback()) + assertFalse(supportsOpenStatementsAcrossCommit()) + assertFalse(supportsOpenStatementsAcrossRollback()) + } + + @Test + fun ddlSupport(): Unit = metaData.run { + assertFalse(dataDefinitionCausesTransactionCommit()) + assertFalse(dataDefinitionIgnoredInTransactions()) + assertTrue(supportsBatchUpdates()) + } + + @Test + fun namingLimits(): Unit = metaData.run { + assertEquals(0, maxBinaryLiteralLength) + assertEquals(0, maxCharLiteralLength) + assertEquals(0, maxColumnNameLength) + assertEquals(0, maxColumnsInGroupBy) + assertEquals(0, maxColumnsInIndex) + assertEquals(0, maxColumnsInOrderBy) + assertEquals(0, maxColumnsInSelect) + assertEquals(0, maxColumnsInTable) + assertEquals(0, maxConnections) + assertEquals(0, maxCursorNameLength) + assertEquals(0, maxIndexLength) + assertEquals(0, maxProcedureNameLength) + assertEquals(0, maxRowSize) + assertEquals(0, maxSchemaNameLength) + assertEquals(0, maxStatementLength) + assertEquals(0, maxStatements) + assertEquals(0, maxTableNameLength) + assertEquals(0, maxTablesInSelect) + assertEquals(0, maxUserNameLength) + } + + @Test + fun miscellaneousProperties(): Unit = metaData.run { + mockConnection.isReadOnly = true + assertTrue(isReadOnly) + assertFalse(locatorsUpdateCopy()) + assertTrue(usesLocalFiles()) + assertTrue(usesLocalFilePerTable()) + assertFalse(storesUpperCaseIdentifiers()) + assertFalse(storesLowerCaseIdentifiers()) + assertTrue(storesMixedCaseIdentifiers()) + assertFalse(storesUpperCaseQuotedIdentifiers()) + assertFalse(storesLowerCaseQuotedIdentifiers()) + assertTrue(storesMixedCaseQuotedIdentifiers()) + assertFalse(supportsMixedCaseIdentifiers()) + assertTrue(supportsMixedCaseQuotedIdentifiers()) + assertTrue(doesMaxRowSizeIncludeBlobs()) + assertFalse(nullsAreSortedHigh()) + assertFalse(nullsAreSortedLow()) + assertTrue(nullsAreSortedAtStart()) + assertFalse(nullsAreSortedAtEnd()) + assertFalse(allProceduresAreCallable()) + assertTrue(allTablesAreSelectable()) + assertEquals("jdbc:selekt:/tmp/test.db", url) + assertEquals("", userName) + } + + @Test + fun unsupportedFeatures(): Unit = metaData.run { + assertFalse(supportsStoredProcedures()) + assertFalse(supportsMultipleResultSets()) + assertFalse(supportsGetGeneratedKeys()) + assertFalse(supportsResultSetHoldability(ResultSet.HOLD_CURSORS_OVER_COMMIT)) + assertTrue(supportsResultSetHoldability(ResultSet.CLOSE_CURSORS_AT_COMMIT)) + assertTrue(supportsSavepoints()) + assertFalse(supportsNamedParameters()) + assertFalse(supportsMultipleOpenResults()) + assertFalse(supportsStatementPooling()) + assertFalse(supportsStoredFunctionsUsingCallSyntax()) + assertFalse(autoCommitFailureClosesAllResultSets()) + } + + @Test + fun getTables() { + whenever(mockDatabase.query(any(), any>())).doReturn(mock()) + assertNotNull(metaData.getTables(null, null, "%", arrayOf("TABLE"))) + } + + @Test + fun getColumns() { + whenever(mockDatabase.query(any(), any>())).doReturn(mock()) + assertNotNull(metaData.getColumns(null, null, "users", "%")) + } + + @Test + fun getPrimaryKeys() { + whenever(mockDatabase.query(any(), any>())).doReturn(mock()) + assertNotNull(metaData.getPrimaryKeys(null, null, "users")) + } + + @Test + fun getIndexInfo() { + whenever(mockDatabase.query(any(), any>())).doReturn(mock()) + assertNotNull(metaData.getIndexInfo(null, null, "users", unique = false, approximate = false)) + } + + @Test + fun unsupportedMetaDataMethods(): Unit = metaData.run { + assertFailsWith { + getProcedures(null, null, "%") + } + assertFailsWith { + getProcedureColumns(null, null, "%", "%") + } + assertFailsWith { + getSchemas() + } + assertFailsWith { + catalogs + } + assertFailsWith { + getTablePrivileges(null, null, "%") + } + assertFailsWith { + getColumnPrivileges(null, null, "users", "%") + } + assertFailsWith { + getBestRowIdentifier(null, null, "users", 0, false) + } + assertFailsWith { + getVersionColumns(null, null, "users") + } + assertFailsWith { + getImportedKeys(null, null, "users") + } + assertFailsWith { + getExportedKeys(null, null, "users") + } + assertFailsWith { + getCrossReference(null, null, "users", null, null, "orders") + } + assertFailsWith { + getUDTs(null, null, "%", null) + } + assertFailsWith { + getSuperTypes(null, null, "%") + } + assertFailsWith { + getSuperTables(null, null, "%") + } + assertFailsWith { + getAttributes(null, null, "%", "%") + } + } + + @Test + fun typeInfo() { + assertNotNull(metaData.typeInfo) + } + + @Test + fun tableTypes() { + assertNotNull(metaData.tableTypes) + } + + @Test + fun wrapperInterface(): Unit = metaData.run { + assertTrue(isWrapperFor(JdbcDatabaseMetaData::class.java)) + assertFalse(isWrapperFor(String::class.java)) + assertSame(metaData, metaData.unwrap(JdbcDatabaseMetaData::class.java)) + assertFailsWith { + unwrap(String::class.java) + } + } + + @Test + fun numericFunctions(): Unit = metaData.numericFunctions.run { + listOf( + "ABS", + "MAX", + "MIN", + "ROUND" + ).forEach { + assertTrue(contains(it)) + } + } + + @Test + fun stringFunctions(): Unit = metaData.stringFunctions.run { + listOf( + "LENGTH", + "LOWER", + "UPPER", + "SUBSTR" + ).forEach { + assertTrue(contains(it)) + } + } + + @Test + fun systemFunctions(): Unit = metaData.systemFunctions.run { + listOf( + "COALESCE", + "IFNULL", + "NULLIF" + ).forEach { + assertTrue(contains(it)) + } + } + + @Test + fun timeDateFunctions(): Unit = metaData.timeDateFunctions.run { + listOf( + "DATE", + "TIME", + "DATETIME", + "STRFTIME" + ).forEach { + assertTrue(contains(it)) + } + } + + @Test + fun searchStringEscape() { + assertEquals("\\", metaData.searchStringEscape) + } + + @Test + fun extraNameCharacters() { + assertEquals("", metaData.extraNameCharacters) + } + + @Test + fun getTablesWithSpecificTypes() { + whenever(mockDatabase.query(any(), any>())).doReturn(mockCursor) + metaData.run { + assertNotNull(getTables(null, null, "%", arrayOf("VIEW"))) + assertNotNull(getTables(null, null, "%", arrayOf("TABLE", "VIEW"))) + } + } + + @Test + fun getTablesWithNullTypes() { + whenever(mockDatabase.query(any(), any>())).doReturn(mockCursor) + assertNotNull(metaData.getTables(null, null, "%", null)) + } + + @Test + fun getTablesWithEmptyTypes() { + whenever(mockDatabase.query(any(), any>())).doReturn(mockCursor) + assertNotNull(metaData.getTables(null, null, "%", arrayOf())) + } + + @Test + fun getTablesWithoutPattern() { + whenever(mockDatabase.query(any(), any>())).doReturn(mockCursor) + assertNotNull(metaData.getTables(null, null, null, null)) + } + + @Test + fun getIndexInfoUnique() { + whenever(mockDatabase.query(any(), any>())).doReturn(mockCursor) + assertNotNull(metaData.getIndexInfo(null, null, "users", unique = true, approximate = false)) + } + + @Test + fun getIndexInfoNonUnique() { + whenever(mockDatabase.query(any(), any>())).doReturn(mockCursor) + assertNotNull(metaData.getIndexInfo(null, null, "users", unique = false, approximate = false)) + } + + @Test + fun unwrapToSQLDatabase() { + val unwrapped = metaData.unwrap(SQLDatabase::class.java) + assertSame(mockDatabase, unwrapped) + } + + @Test + fun isWrapperForSQLDatabase() { + assertTrue(metaData.isWrapperFor(SQLDatabase::class.java)) + } + + @Test + fun supportsResultSetTypeScrollInsensitive() { + assertFalse(metaData.supportsResultSetType(ResultSet.TYPE_SCROLL_INSENSITIVE)) + } + + @Test + fun supportsResultSetTypeScrollSensitive() { + assertFalse(metaData.supportsResultSetType(ResultSet.TYPE_SCROLL_SENSITIVE)) + } + + @Test + fun supportsResultSetConcurrencyUpdatable() { + assertFalse(metaData.supportsResultSetConcurrency( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_UPDATABLE + )) + } + + @Test + fun transactionIsolationLevels(): Unit = metaData.run { + assertFalse(supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_COMMITTED)) + assertFalse(supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_UNCOMMITTED)) + assertFalse(supportsTransactionIsolationLevel(Connection.TRANSACTION_REPEATABLE_READ)) + assertTrue(supportsTransactionIsolationLevel(Connection.TRANSACTION_SERIALIZABLE)) + } + + @Test + fun visibilityMethods(): Unit = metaData.run { + assertFalse(othersDeletesAreVisible(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(othersInsertsAreVisible(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(othersUpdatesAreVisible(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(ownDeletesAreVisible(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(ownInsertsAreVisible(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(ownUpdatesAreVisible(ResultSet.TYPE_FORWARD_ONLY)) + } + + @Test + fun detectionMethods(): Unit = metaData.run { + assertFalse(deletesAreDetected(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(insertsAreDetected(ResultSet.TYPE_FORWARD_ONLY)) + assertFalse(updatesAreDetected(ResultSet.TYPE_FORWARD_ONLY)) + } + + @Test + fun additionalSupportMethods(): Unit = metaData.run { + assertFalse(supportsCatalogsInDataManipulation()) + assertFalse(supportsCatalogsInIndexDefinitions()) + assertFalse(supportsCatalogsInPrivilegeDefinitions()) + assertFalse(supportsCatalogsInProcedureCalls()) + assertFalse(supportsCatalogsInTableDefinitions()) + assertFalse(supportsSchemasInDataManipulation()) + assertFalse(supportsSchemasInProcedureCalls()) + assertFalse(supportsSchemasInTableDefinitions()) + assertFalse(supportsSchemasInIndexDefinitions()) + assertFalse(supportsSchemasInPrivilegeDefinitions()) + } + + @Test + fun additionalProperties(): Unit = metaData.run { + assertTrue(supportsDataDefinitionAndDataManipulationTransactions()) + assertFalse(supportsDataManipulationTransactionsOnly()) + assertFalse(supportsIntegrityEnhancementFacility()) + assertTrue(nullPlusNonNullIsNull()) + assertFalse(generatedKeyAlwaysReturned()) + assertEquals(RowIdLifetime.ROWID_UNSUPPORTED, metaData.rowIdLifetime) + } + + @Test + fun resultSetHoldability() { + assertEquals(ResultSet.HOLD_CURSORS_OVER_COMMIT, metaData.resultSetHoldability) + } + + @Test + fun sqlStateType() { + assertEquals(DatabaseMetaData.sqlStateSQL99, metaData.sqlStateType) + } + + @Test + fun databaseVersions(): Unit = metaData.run { + assertEquals(3, databaseMajorVersion) + assertEquals(51, databaseMinorVersion) + } + + @Test + fun unsupportedFeaturesMethods(): Unit = metaData.run { + assertFailsWith { + clientInfoProperties + } + assertFailsWith { + getFunctions(null, null, "%") + } + assertFailsWith { + getFunctionColumns(null, null, "%", "%") + } + assertFailsWith { + getPseudoColumns(null, null, "%", "%") + } + } + + @Test + fun getSchemasWithParameters() { + assertFailsWith { + metaData.getSchemas("catalog", "schema") + } + } + + @Test + fun supportsSelectForUpdate() { + assertFalse(metaData.supportsSelectForUpdate()) + } + + @Test + fun supportsPositionedUpdate() { + assertFalse(metaData.supportsPositionedUpdate()) + } + + @Test + fun supportsPositionedDelete() { + assertFalse(metaData.supportsPositionedDelete()) + } + + @Test + fun getMaxCatalogNameLength() { + assertEquals(0, metaData.maxCatalogNameLength) + } + + @Test + fun supportsConvertWithParameters() { + assertTrue(metaData.supportsConvert(Types.INTEGER, Types.VARCHAR)) + } + + @Suppress("Detekt.CognitiveComplexMethod", "Detekt.LongMethod") + @Test + fun getColumnsWithVariousColumnTypes() { + val tablesColumnNames = arrayOf("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE", "REMARKS", "TYPE_CAT", + "TYPE_SCHEM", "TYPE_NAME", "SELF_REFERENCING_COL_NAME", "REF_GENERATION") + val tablesCursor = mock { + var callCount = 0 + whenever(it.moveToNext()) doAnswer { + ++callCount == 1 + } + whenever(it.columnCount) doReturn tablesColumnNames.size + whenever(it.columnNames()) doReturn tablesColumnNames + whenever(it.columnIndex(any())) doAnswer { invocation -> + val name = invocation.getArgument(0) + tablesColumnNames.indexOf(name) + } + whenever(it.getString(any())) doAnswer { invocation -> + when (tablesColumnNames.getOrNull(invocation.getArgument(0))) { + "TABLE_NAME" -> "test_table" + else -> null + } + } + whenever(it.isClosed()) doReturn false + } + val pragmaColumnNames = arrayOf("cid", "name", "type", "notnull", "dflt_value", "pk") + val pragmaCursor = mock { + var callCount = 0 + val columns = listOf( + listOf(0, "int_col", "INTEGER", 0, null, 0), + listOf(1, "text_col", "TEXT", 0, null, 0), + listOf(2, "real_col", "REAL", 0, null, 0), + listOf(3, "blob_col", "BLOB", 0, null, 0), + listOf(4, "numeric_col", "NUMERIC", 0, null, 0), + listOf(5, "varchar_col", "VARCHAR", 0, null, 0), + listOf(6, "smallint_col", "SMALLINT", 0, null, 0), + listOf(7, "bigint_col", "BIGINT", 0, null, 0), + listOf(8, "double_col", "DOUBLE", 0, null, 0), + listOf(9, "float_col", "FLOAT", 0, null, 0), + listOf(10, "decimal_col", "DECIMAL", 0, null, 0), + listOf(11, "char_col", "CHAR", 0, null, 0), + listOf(12, "clob_col", "CLOB", 0, null, 0), + listOf(13, "null_col", "NULL", 0, null, 0), + listOf(14, "unknown_col", "UNKNOWN", 0, null, 0), + listOf(15, "mediumint_col", "MEDIUMINT", 0, null, 0), + listOf(16, "doubleprecision_col", "DOUBLE PRECISION", 0, null, 0), + listOf(17, "character_col", "CHARACTER", 0, null, 0) + ) + whenever(it.moveToNext()) doAnswer { + ++callCount <= columns.size + } + whenever(it.columnCount) doReturn pragmaColumnNames.size + whenever(it.columnNames()) doReturn pragmaColumnNames + whenever(it.columnIndex(any())) doAnswer { invocation -> + pragmaColumnNames.indexOf(invocation.getArgument(0)) + } + whenever(it.getString(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0))?.toString() + } else { + null + } + } + whenever(it.getInt(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0)) as? Int ?: 0 + } else { + 0 + } + } + whenever(it.position()) doAnswer { + callCount - 1 + } + whenever(it.type(any())) doAnswer { invocation -> + when (invocation.getArgument(0)) { + 0, 3, 5 -> ColumnType.INTEGER + 1, 2 -> ColumnType.STRING + 4 -> ColumnType.NULL + else -> ColumnType.STRING + } + } + whenever(it.isNull(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0)) == null + } else { + true + } + } + whenever(it.isClosed()) doReturn false + } + val unionCursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 24 + whenever(it.columnNames()) doReturn arrayOf( + "TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "TYPE_NAME", + "COLUMN_SIZE", "BUFFER_LENGTH", "DECIMAL_DIGITS", "NUM_PREC_RADIX", "NULLABLE", + "REMARKS", "COLUMN_DEF", "SQL_DATA_TYPE", "SQL_DATETIME_SUB", "CHAR_OCTET_LENGTH", + "ORDINAL_POSITION", "IS_NULLABLE", "SCOPE_CATALOG", "SCOPE_SCHEMA", "SCOPE_TABLE", + "SOURCE_DATA_TYPE", "IS_AUTOINCREMENT", "IS_GENERATEDCOLUMN" + ) + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { + val sql = it.getArgument(0) + when { + sql.contains("sqlite_master") && sql.contains("type IN ('table','view')") -> tablesCursor + sql.startsWith("PRAGMA table_info") -> pragmaCursor + sql.contains("UNION ALL") -> unionCursor + else -> mockCursor + } + } + metaData.getColumns(null, null, "%", "%").use { + assertNotNull(it) + } + } + + @Suppress("Detekt.CognitiveComplexMethod", "Detekt.LongMethod") + @Test + fun getPrimaryKeysWithMultipleKeys() { + val pragmaColumnNames = arrayOf("cid", "name", "type", "notnull", "dflt_value", "pk") + val pragmaCursor = mock { + var callCount = 0 + val columns = listOf( + listOf(0, "id", "INTEGER", 1, null, 1), + listOf(1, "sub_id", "INTEGER", 1, null, 2), + listOf(2, "name", "TEXT", 0, null, 0) + ) + whenever(it.moveToNext()) doAnswer { + ++callCount <= columns.size + } + whenever(it.columnCount) doReturn pragmaColumnNames.size + whenever(it.columnNames()) doReturn pragmaColumnNames + whenever(it.columnIndex(any())) doAnswer { invocation -> + pragmaColumnNames.indexOf(invocation.getArgument(0)) + } + whenever(it.getString(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0))?.toString() + } else { + null + } + } + whenever(it.getInt(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0)) as? Int ?: 0 + } else { + 0 + } + } + whenever(it.position()) doAnswer { + callCount - 1 + } + whenever(it.type(any())) doAnswer { invocation -> + when (invocation.getArgument(0)) { + 0, 3, 5 -> ColumnType.INTEGER + 1, 2 -> ColumnType.STRING + 4 -> ColumnType.NULL + else -> ColumnType.STRING + } + } + whenever(it.isNull(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0)) == null + } else { + true + } + } + whenever(it.isClosed()) doReturn false + } + val unionCursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 6 + whenever(it.columnNames()) doReturn arrayOf( + "TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "KEY_SEQ", "PK_NAME" + ) + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { + val sql = it.getArgument(0) + when { + sql.startsWith("PRAGMA table_info") -> pragmaCursor + sql.contains("UNION ALL") -> unionCursor + else -> mockCursor + } + } + metaData.getPrimaryKeys(null, null, "test_table").use { + assertNotNull(it) + } + } + + @Suppress("Detekt.CognitiveComplexMethod", "Detekt.LongMethod") + @Test + fun getColumnsWithIntType() { + val tablesColumnNames = arrayOf("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE", "REMARKS", "TYPE_CAT", + "TYPE_SCHEM", "TYPE_NAME", "SELF_REFERENCING_COL_NAME", "REF_GENERATION") + val tablesCursor = mock { + var callCount = 0 + whenever(it.moveToNext()) doAnswer { + callCount++ + callCount == 1 + } + whenever(it.columnCount) doReturn tablesColumnNames.size + whenever(it.columnNames()) doReturn tablesColumnNames + whenever(it.columnIndex(any())) doAnswer { invocation -> + tablesColumnNames.indexOf(invocation.getArgument(0)) + } + whenever(it.getString(any())) doAnswer { invocation -> + when (tablesColumnNames.getOrNull(invocation.getArgument(0))) { + "TABLE_NAME" -> "test_table" + else -> null + } + } + whenever(it.isClosed()) doReturn false + } + val pragmaColumnNames = arrayOf("cid", "name", "type", "notnull", "dflt_value", "pk") + val pragmaCursor = mock { + var callCount = 0 + val columns = listOf(listOf(0, "int_type_col", "INT", 1, "42", 1)) + whenever(it.moveToNext()) doAnswer { + ++callCount <= columns.size + } + whenever(it.columnCount) doReturn pragmaColumnNames.size + whenever(it.columnNames()) doReturn pragmaColumnNames + whenever(it.columnIndex(any())) doAnswer { invocation -> + pragmaColumnNames.indexOf(invocation.getArgument(0)) + } + whenever(it.getString(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0))?.toString() + } else { + null + } + } + whenever(it.getInt(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0)) as? Int ?: 0 + } else { + 0 + } + } + whenever(it.position()) doAnswer { + callCount - 1 + } + whenever(it.type(any())) doAnswer { invocation -> + when (invocation.getArgument(0)) { + 0, 3, 5 -> ColumnType.INTEGER + 1, 2 -> ColumnType.STRING + 4 -> ColumnType.STRING + else -> ColumnType.STRING + } + } + whenever(it.isClosed()) doReturn false + } + val unionCursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 24 + whenever(it.columnNames()) doReturn arrayOf( + "TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "TYPE_NAME", + "COLUMN_SIZE", "BUFFER_LENGTH", "DECIMAL_DIGITS", "NUM_PREC_RADIX", "NULLABLE", + "REMARKS", "COLUMN_DEF", "SQL_DATA_TYPE", "SQL_DATETIME_SUB", "CHAR_OCTET_LENGTH", + "ORDINAL_POSITION", "IS_NULLABLE", "SCOPE_CATALOG", "SCOPE_SCHEMA", "SCOPE_TABLE", + "SOURCE_DATA_TYPE", "IS_AUTOINCREMENT", "IS_GENERATEDCOLUMN" + ) + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { invocation -> + val sql = invocation.getArgument(0) + when { + sql.contains("sqlite_master") && sql.contains("type IN ('table','view')") -> tablesCursor + sql.startsWith("PRAGMA table_info") -> pragmaCursor + sql.contains("UNION ALL") -> unionCursor + else -> mockCursor + } + } + metaData.getColumns(null, null, "%", "%").use { + assertNotNull(it) + } + } + + @Test + fun getColumnsWithEmptyResult() { + val tablesColumnNames = arrayOf("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE", "REMARKS", "TYPE_CAT", + "TYPE_SCHEM", "TYPE_NAME", "SELF_REFERENCING_COL_NAME", "REF_GENERATION") + val tablesCursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn tablesColumnNames.size + whenever(it.columnNames()) doReturn tablesColumnNames + whenever(it.isClosed()) doReturn false + } + val emptyCursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 24 + whenever(it.columnNames()) doReturn arrayOf( + "TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "TYPE_NAME", + "COLUMN_SIZE", "BUFFER_LENGTH", "DECIMAL_DIGITS", "NUM_PREC_RADIX", "NULLABLE", + "REMARKS", "COLUMN_DEF", "SQL_DATA_TYPE", "SQL_DATETIME_SUB", "CHAR_OCTET_LENGTH", + "ORDINAL_POSITION", "IS_NULLABLE", "SCOPE_CATALOG", "SCOPE_SCHEMA", "SCOPE_TABLE", + "SOURCE_DATA_TYPE", "IS_AUTOINCREMENT", "IS_GENERATEDCOLUMN" + ) + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { invocation -> + val sql = invocation.getArgument(0) + when { + sql.contains("sqlite_master") && sql.contains("type IN ('table','view')") -> tablesCursor + sql.contains("WHERE 1 = 0") -> emptyCursor + else -> mockCursor + } + } + metaData.getColumns(null, null, "%", "%").use { + assertNotNull(it) + } + } + + @Suppress("Detekt.CognitiveComplexMethod", "Detekt.LongMethod") + @Test + fun getPrimaryKeysWithEmptyResult() { + val pragmaColumnNames = arrayOf("cid", "name", "type", "notnull", "dflt_value", "pk") + val pragmaCursor = mock { + var callCount = 0 + val columns = listOf( + listOf(0, "id", "INTEGER", 0, null, 0), + listOf(1, "name", "TEXT", 0, null, 0) + ) + whenever(it.moveToNext()) doAnswer { + ++callCount <= columns.size + } + whenever(it.columnCount) doReturn pragmaColumnNames.size + whenever(it.columnNames()) doReturn pragmaColumnNames + whenever(it.columnIndex(any())) doAnswer { invocation -> + val name = invocation.getArgument(0) + pragmaColumnNames.indexOf(name) + } + whenever(it.getString(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0))?.toString() + } else { + null + } + } + whenever(it.getInt(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0)) as? Int ?: 0 + } else { + 0 + } + } + whenever(it.position()) doAnswer { + callCount - 1 + } + whenever(it.type(any())) doAnswer { invocation -> + val index = invocation.getArgument(0) + when (index) { + 0, 3, 5 -> ColumnType.INTEGER + 1, 2 -> ColumnType.STRING + 4 -> ColumnType.NULL + else -> ColumnType.STRING + } + } + whenever(it.isNull(any())) doAnswer { invocation -> + if (callCount > 0 && callCount <= columns.size) { + columns[callCount - 1].getOrNull(invocation.getArgument(0)) == null + } else { + true + } + } + whenever(it.isClosed()) doReturn false + } + val emptyCursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 6 + whenever(it.columnNames()) doReturn arrayOf( + "TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "KEY_SEQ", "PK_NAME" + ) + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { + val sql = it.getArgument(0) + when { + sql.startsWith("PRAGMA table_info") -> pragmaCursor + sql.contains("WHERE 1 = 0") -> emptyCursor + else -> mockCursor + } + } + metaData.getPrimaryKeys(null, null, "test_table").use { + assertNotNull(it) + } + } + + @Test + fun getTablesWithNullPattern() { + val cursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 10 + whenever(it.columnNames()) doReturn arrayOf("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE", + "REMARKS", "TYPE_CAT", "TYPE_SCHEM", "TYPE_NAME", "SELF_REFERENCING_COL_NAME", "REF_GENERATION") + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { invocation -> + val sql = invocation.getArgument(0) + when { + sql.contains("sqlite_master") -> cursor + else -> mockCursor + } + } + metaData.getTables(null, null, null, null).use { + assertNotNull(it) + } + } + + @Test + fun getTablesWithSpecificPattern() { + val cursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 10 + whenever(it.columnNames()) doReturn arrayOf("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE", + "REMARKS", "TYPE_CAT", "TYPE_SCHEM", "TYPE_NAME", "SELF_REFERENCING_COL_NAME", "REF_GENERATION") + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { invocation -> + val sql = invocation.getArgument(0) + when { + sql.contains("sqlite_master") && sql.contains("AND name GLOB") -> cursor + else -> mockCursor + } + } + metaData.getTables(null, null, "test_%", null).use { + assertNotNull(it) + } + } + + @Test + fun getIndexInfoWithUniqueFlag() { + val cursor = mock { + whenever(it.moveToNext()) doReturn false + whenever(it.columnCount) doReturn 13 + whenever(it.columnNames()) doReturn arrayOf("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "NON_UNIQUE", + "INDEX_QUALIFIER", "INDEX_NAME", "TYPE", "ORDINAL_POSITION", "COLUMN_NAME", "ASC_OR_DESC", + "CARDINALITY", "PAGES", "FILTER_CONDITION") + whenever(it.isClosed()) doReturn false + } + whenever(mockDatabase.query(any(), any>())) doAnswer { invocation -> + val sql = invocation.getArgument(0) + when { + sql.contains("AND \"unique\" = 1") -> cursor + else -> mockCursor + } + } + metaData.getIndexInfo(null, null, "test_table", unique = true, approximate = false).use { + assertNotNull(it) + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetMetaDataTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetMetaDataTest.kt new file mode 100644 index 0000000000..68671f5750 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetMetaDataTest.kt @@ -0,0 +1,274 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.result + +import com.bloomberg.selekt.ColumnType +import com.bloomberg.selekt.ICursor +import java.sql.ResultSetMetaData +import java.sql.SQLException +import java.sql.Types +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertSame +import kotlin.test.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever + +internal class JdbcResultSetMetaDataTest { + private lateinit var mockCursor: ICursor + private lateinit var metaData: JdbcResultSetMetaData + + @BeforeEach + fun setUp() { + mockCursor = mock { + whenever(it.columnCount) doReturn 4 + whenever(it.columnNames()) doReturn arrayOf("id", "name", "age", "balance") + whenever(it.columnName(0)) doReturn "id" + whenever(it.type(0)) doReturn ColumnType.INTEGER + whenever(it.columnName(1)) doReturn "name" + whenever(it.type(1)) doReturn ColumnType.STRING + whenever(it.columnName(2)) doReturn "age" + whenever(it.type(2)) doReturn ColumnType.INTEGER + whenever(it.columnName(3)) doReturn "balance" + whenever(it.type(3)) doReturn ColumnType.FLOAT + } + metaData = JdbcResultSetMetaData(mockCursor) + } + + @Test + fun getColumnCount() { + assertEquals(4, metaData.columnCount) + } + + @Test + fun getColumnName(): Unit = metaData.run { + assertEquals("id", getColumnName(1)) + assertEquals("name", getColumnName(2)) + assertEquals("age", getColumnName(3)) + assertEquals("balance", getColumnName(4)) + } + + @Test + fun getColumnLabel(): Unit = metaData.run { + assertEquals("id", getColumnLabel(1)) + assertEquals("name", getColumnLabel(2)) + assertEquals("age", getColumnLabel(3)) + assertEquals("balance", getColumnLabel(4)) + } + + @Test + fun getColumnType(): Unit = metaData.run { + assertEquals(Types.BIGINT, getColumnType(1)) + assertEquals(Types.VARCHAR, getColumnType(2)) + assertEquals(Types.BIGINT, getColumnType(3)) + assertEquals(Types.DOUBLE, getColumnType(4)) + } + + @Test + fun getColumnTypeName(): Unit = metaData.run { + assertEquals("BIGINT", getColumnTypeName(1)) + assertEquals("VARCHAR", getColumnTypeName(2)) + assertEquals("BIGINT", getColumnTypeName(3)) + assertEquals("DOUBLE", getColumnTypeName(4)) + } + + @Test + fun getColumnClassName(): Unit = metaData.run { + assertEquals(Long::class.java.name, getColumnClassName(1)) + assertEquals(String::class.java.name, getColumnClassName(2)) + assertEquals(Long::class.java.name, getColumnClassName(3)) + assertEquals(Double::class.java.name, getColumnClassName(4)) + } + + @Test + fun getPrecision(): Unit = metaData.run { + assertEquals(19, getPrecision(1)) + assertEquals(0, getPrecision(2)) + assertEquals(19, getPrecision(3)) + assertEquals(15, getPrecision(4)) + } + + @Test + fun getScale(): Unit = metaData.run { + assertEquals(0, getScale(1)) + assertEquals(0, getScale(2)) + assertEquals(0, getScale(3)) + assertEquals(15, getScale(4)) + } + + @Test + fun getDisplaySize(): Unit = metaData.run { + assertEquals(20, getColumnDisplaySize(1)) + assertEquals(Integer.MAX_VALUE, getColumnDisplaySize(2)) + assertEquals(20, getColumnDisplaySize(3)) + assertEquals(24, getColumnDisplaySize(4)) + } + + @Test + fun isNullable(): Unit = metaData.run { + assertEquals(ResultSetMetaData.columnNullableUnknown, isNullable(1)) + assertEquals(ResultSetMetaData.columnNullableUnknown, isNullable(2)) + assertEquals(ResultSetMetaData.columnNullableUnknown, isNullable(3)) + assertEquals(ResultSetMetaData.columnNullableUnknown, isNullable(4)) + } + + @Test + fun columnProperties(): Unit = metaData.run { + assertFalse(isAutoIncrement(1)) + assertFalse(isCaseSensitive(1)) + assertTrue(isSearchable(1)) + assertFalse(isCurrency(1)) + assertTrue(isReadOnly(1)) + assertFalse(isWritable(1)) + assertFalse(isDefinitelyWritable(1)) + assertTrue(isSigned(1)) + } + + @Test + fun stringColumnProperties(): Unit = metaData.run { + assertTrue(isCaseSensitive(2)) + assertFalse(isSigned(2)) + } + + @Test + fun floatColumnProperties(): Unit = metaData.run { + assertTrue(isSigned(4)) + assertFalse(isCurrency(4)) + } + + @Test + fun invalidColumnIndex(): Unit = metaData.run { + assertFailsWith { + getColumnName(0) + } + assertFailsWith { + getColumnType(5) + } + assertFailsWith { + getColumnName(-1) + } + } + + @Test + fun schemaAndCatalogInfo(): Unit = metaData.run { + assertTrue(getSchemaName(1).isEmpty()) + assertTrue(getCatalogName(1).isEmpty()) + assertTrue(getTableName(1).isEmpty()) + } + + @Test + fun wrapperInterface(): Unit = metaData.run { + assertTrue(isWrapperFor(JdbcResultSetMetaData::class.java)) + assertFalse(isWrapperFor(String::class.java)) + assertSame(this, unwrap(JdbcResultSetMetaData::class.java)) + assertFailsWith { + unwrap(String::class.java) + } + } + + @Test + fun nullColumnType(): Unit = metaData.run { + whenever(mockCursor.type(0)) doReturn ColumnType.NULL + assertEquals(Types.NULL, getColumnType(1)) + assertEquals("NULL", getColumnTypeName(1)) + } + + @Test + fun blobColumnType(): Unit = metaData.run { + whenever(mockCursor.type(0)) doReturn ColumnType.BLOB + assertEquals(Types.VARBINARY, getColumnType(1)) + assertEquals("VARBINARY", getColumnTypeName(1)) + assertEquals(ByteArray::class.java.name, getColumnClassName(1)) + } + + @Test + fun columnIndexValidation(): Unit = metaData.run { + val invalidIndices = listOf(0, -1, 5) + val validMethods = listOf<(Int) -> Any>( + ::getColumnName, + ::getColumnLabel, + ::getColumnType, + ::getColumnTypeName, + ::getColumnClassName, + ::getPrecision, + ::getScale, + ::getColumnDisplaySize, + ::isNullable, + ::isAutoIncrement, + ::isCaseSensitive, + ::isSearchable, + ::isCurrency, + ::isReadOnly, + ::isWritable, + ::isDefinitelyWritable, + ::isSigned, + ::getSchemaName, + ::getCatalogName, + ::getTableName + ) + for (invalidIndex in invalidIndices) { + for (method in validMethods) { + assertFailsWith("Method should throw for index $invalidIndex") { + method(invalidIndex) + } + } + } + } + + @Test + fun allColumnTypes() { + val typeTests = listOf( + ColumnType.INTEGER to Types.BIGINT, + ColumnType.FLOAT to Types.DOUBLE, + ColumnType.STRING to Types.VARCHAR, + ColumnType.BLOB to Types.VARBINARY, + ColumnType.NULL to Types.NULL + ) + val testCursor = mock { + whenever(it.columnCount) doReturn 5 + whenever(it.columnNames()) doReturn arrayOf("col1", "col2", "col3", "col4", "col5") + whenever(it.position()) doReturn 0 + whenever(it.isBeforeFirst()) doReturn false + whenever(it.isAfterLast()) doReturn false + typeTests.forEachIndexed { index, (selektType, _) -> + whenever(it.type(index)) doReturn selektType + } + } + val testMetaData = JdbcResultSetMetaData(testCursor) + typeTests.forEachIndexed { index, (_, expectedJdbcType) -> + assertEquals(expectedJdbcType, testMetaData.getColumnType(index + 1)) + } + } + + @Test + fun emptyResultSet() { + val emptyCursor = mock { + whenever(it.columnCount) doReturn 0 + whenever(it.columnNames()) doReturn emptyArray() + } + JdbcResultSetMetaData(emptyCursor).run { + assertEquals(0, columnCount) + assertFailsWith { + getColumnName(1) + } + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetTest.kt new file mode 100644 index 0000000000..e308c33b31 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/result/JdbcResultSetTest.kt @@ -0,0 +1,1320 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.result + +import com.bloomberg.selekt.ColumnType +import com.bloomberg.selekt.ICursor +import com.bloomberg.selekt.jdbc.statement.JdbcStatement +import java.io.InputStream +import java.io.Reader +import java.math.BigDecimal +import java.sql.Blob +import java.sql.Clob +import java.sql.Date +import java.sql.NClob +import java.sql.Ref +import java.sql.ResultSet +import java.sql.RowId +import java.sql.SQLException +import java.sql.SQLXML +import java.sql.Time +import java.sql.Timestamp +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertSame +import kotlin.test.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.doThrow +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever + +internal class JdbcResultSetTest { + private lateinit var mockCursor: ICursor + private lateinit var mockStatement: JdbcStatement + private lateinit var resultSet: JdbcResultSet + + @BeforeEach + fun setUp() { + mockCursor = mock { + whenever(it.columnCount) doReturn 4 + whenever(it.columnNames()) doReturn arrayOf("id", "name", "age", "balance") + + whenever(it.columnName(0)) doReturn "id" + whenever(it.columnName(1)) doReturn "name" + whenever(it.columnName(2)) doReturn "age" + whenever(it.columnName(3)) doReturn "balance" + + whenever(it.columnIndex(any())) doReturn -1 + whenever(it.columnIndex("id")) doReturn 0 + whenever(it.columnIndex("name")) doReturn 1 + whenever(it.columnIndex("age")) doReturn 2 + whenever(it.columnIndex("balance")) doReturn 3 + + whenever(it.type(0)) doReturn ColumnType.INTEGER + whenever(it.type(1)) doReturn ColumnType.STRING + whenever(it.type(2)) doReturn ColumnType.INTEGER + whenever(it.type(3)) doReturn ColumnType.FLOAT + } + mockStatement = mock() + resultSet = JdbcResultSet(mockCursor, mockStatement) + } + + @Test + fun next() { + whenever(mockCursor.moveToNext()).doReturn(true, true, false) + resultSet.run { + assertTrue(next()) + assertTrue(next()) + assertFalse(next()) + } + } + + @Test + fun getIntByIndex() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getInt(0)) doReturn 42 + } + assertEquals(42, resultSet.getInt(1)) + } + + @Test + fun getIntByName() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(mockCursor.getInt(0)) doReturn 42 + } + assertEquals(42, resultSet.getInt("id")) + } + + @Test + fun getStringByIndex() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "John Doe" + } + assertEquals("John Doe", resultSet.getString(2)) + } + + @Test + fun getStringByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "John Doe" + } + assertEquals("John Doe", resultSet.getString("name")) + } + + @Test + fun getDoubleByIndex() { + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getDouble(3)) doReturn 123.45 + } + assertEquals(123.45, resultSet.getDouble(4), 0.001) + } + + @Test + fun getBooleanFromInteger() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doReturn 1L + } + assertTrue(resultSet.getBoolean(1)) + whenever(mockCursor.getLong(0)) doReturn 0L + assertFalse(resultSet.getBoolean(1)) + } + + @Test + fun getBooleanFromString() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getLong(1)) doThrow IllegalArgumentException("Not a number") + whenever(getString(1)) doReturn "true" + } + assertTrue(resultSet.getBoolean(2)) + whenever(mockCursor.getString(1)) doReturn "false" + assertFalse(resultSet.getBoolean(2)) + } + + @Test + fun getBytes() { + val testBytes = byteArrayOf(1, 2, 3, 4) + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getBlob(1)) doReturn testBytes + } + assertEquals(testBytes.toList(), resultSet.getBytes(2)?.toList()) + } + + @Test + fun getDate() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2025-12-25" + } + assertEquals(Date.valueOf("2025-12-25"), resultSet.getDate(2)) + } + + @Test + fun getTime() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "10:30:45" + } + assertEquals(Time.valueOf("10:30:45"), resultSet.getTime(2)) + } + + @Test + fun getTimestamp() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2025-12-25 10:30:45" + } + assertEquals(Timestamp.valueOf("2025-12-25 10:30:45"), resultSet.getTimestamp(2)) + } + + @Test + fun getBigDecimal() { + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getDouble(3)) doReturn 123.45 + } + assertEquals(BigDecimal.valueOf(123.45), resultSet.getBigDecimal(4)) + } + + @Test + fun getObject() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doReturn 42L + } + assertEquals(42L, resultSet.getObject(1)) + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "test" + } + assertEquals("test", resultSet.getObject(2)) + } + + @Test + fun wasNull() { + whenever(mockCursor.isNull(0)) doReturn true + resultSet.run { + getInt(1) + assertTrue(wasNull()) + } + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "test" + } + resultSet.run { + getString(2) + assertFalse(wasNull()) + } + } + + @Test + fun findColumn(): Unit = resultSet.run { + assertEquals(1, findColumn("id")) + assertEquals(2, findColumn("name")) + assertEquals(3, findColumn("age")) + assertEquals(4, findColumn("balance")) + assertFailsWith { + findColumn("nonexistent") + } + } + + @Test + fun invalidColumnIndex() { + assertFailsWith { + resultSet.getInt(0) + } + assertFailsWith { + resultSet.getString(5) + } + } + + @Test + fun metaData() { + assertEquals(4, resultSet.metaData.columnCount) + } + + @Test + fun getStatement() { + assertEquals(mockStatement, resultSet.statement) + } + + @Test + fun closure() { + whenever(mockCursor.isClosed()) doReturn false + assertFalse(resultSet.isClosed) + whenever(mockCursor.isClosed()) doReturn true + resultSet.close() + assertTrue(resultSet.isClosed) + assertFailsWith { + resultSet.next() + } + assertFailsWith { + resultSet.getString(1) + } + } + + @Test + fun cursorMovement() { + assertEquals(ResultSet.FETCH_FORWARD, resultSet.fetchDirection) + assertFailsWith { + resultSet.fetchDirection = ResultSet.FETCH_REVERSE + } + } + + @Test + fun unsupportedOperations(): Unit = resultSet.run { + assertFailsWith { + previous() + } + assertFailsWith { + first() + } + assertFailsWith { + last() + } + assertFailsWith { + absolute(1) + } + assertFailsWith { + relative(1) + } + assertFailsWith { + updateString(1, "test") + } + assertFailsWith { + updateString("label", "test") + } + assertFailsWith { + insertRow() + } + assertFailsWith { + deleteRow() + } + } + + @Test + fun resultSetProperties(): Unit = resultSet.run { + assertEquals(ResultSet.TYPE_FORWARD_ONLY, type) + assertEquals(ResultSet.CONCUR_READ_ONLY, concurrency) + assertEquals(ResultSet.CLOSE_CURSORS_AT_COMMIT, holdability) + } + + @Test + fun warnings(): Unit = resultSet.run { + assertNull(warnings) + clearWarnings() + } + + @Test + fun rowOperations() { + mockCursor.run { + whenever(position()) doReturn 0 + whenever(isBeforeFirst()) doReturn true + whenever(isFirst()) doReturn false + whenever(isLast()) doReturn false + whenever(isAfterLast()) doReturn false + } + resultSet.run { + assertEquals(0, row) + assertFalse(isFirst()) + assertFalse(isLast()) + assertTrue(isBeforeFirst()) + assertFalse(isAfterLast()) + } + } + + @Test + fun fetchSize(): Unit = resultSet.run { + assertEquals(0, fetchSize) + fetchSize = 100 + assertEquals(100, fetchSize) + assertFailsWith { + fetchSize = -1 + } + } + + @Test + fun wrapperInterface(): Unit = resultSet.run { + assertTrue(isWrapperFor(JdbcResultSet::class.java)) + assertFalse(isWrapperFor(String::class.java)) + assertEquals(resultSet, unwrap(JdbcResultSet::class.java)) + assertFailsWith { + resultSet.unwrap(String::class.java) + } + } + + @Test + fun cursorName() { + assertFailsWith { + resultSet.cursorName + } + } + + @Test + fun nullValues() { + whenever(mockCursor.isNull(0)) doReturn true + resultSet.run { + assertEquals(0, getInt(1)) + assertTrue(wasNull()) + assertNull(getString(1)) + assertTrue(wasNull()) + assertNull(getObject(1)) + assertTrue(wasNull()) + } + } + + @Test + fun typeConversions() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getInt(0)) doReturn 42 + whenever(getShort(0)) doReturn 42.toShort() + whenever(getFloat(0)) doReturn 42.0f + } + resultSet.run { + assertEquals(42.toByte(), getByte(1)) + assertEquals(42.toShort(), getShort(1)) + assertEquals(42.toFloat(), getFloat(1), 0.001f) + } + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "123" + } + assertEquals(123, resultSet.getInt(2)) + whenever(mockCursor.getString(1)) doReturn "123.45" + assertEquals(123.45, resultSet.getDouble(2), 0.001) + } + + @Test + fun getLongByIndex() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doReturn 123_456_789L + } + assertEquals(123_456_789L, resultSet.getLong(1)) + } + + @Test + fun getLongByName() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doReturn 987_654_321L + } + assertEquals(987_654_321L, resultSet.getLong("id")) + } + + @Test + fun getFloatByName() { + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getFloat(3)) doReturn 99.5f + } + assertEquals(99.5f, resultSet.getFloat("balance"), 0.001f) + } + + @Test + fun getDoubleByName() { + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getDouble(3)) doReturn 555.55 + } + assertEquals(555.55, resultSet.getDouble("balance"), 0.001) + } + + @Test + fun getByteByName() { + mockCursor.run { + whenever(isNull(2)) doReturn false + whenever(getInt(2)) doReturn 25 + } + assertEquals(25.toByte(), resultSet.getByte("age")) + } + + @Test + fun getShortByName() { + mockCursor.run { + whenever(isNull(2)) doReturn false + whenever(getShort(2)) doReturn 300.toShort() + } + assertEquals(300.toShort(), resultSet.getShort("age")) + } + + @Test + fun getBigDecimalByName() { + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getDouble(3)) doReturn 999.99 + } + assertEquals(BigDecimal.valueOf(999.99), resultSet.getBigDecimal("balance")) + } + + @Test + fun getBigDecimalNullValue() { + whenever(mockCursor.isNull(3)) doReturn true + resultSet.run { + assertNull(getBigDecimal(4)) + assertTrue(wasNull()) + } + } + + @Test + fun getBigDecimalWithScale() { + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getDouble(3)) doReturn 123.456789 + } + @Suppress("DEPRECATION") + val result = resultSet.getBigDecimal(4, 2) + assertEquals(BigDecimal.valueOf(123.46), result) + } + + @Test + fun getBigDecimalWithScaleByName() { + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getDouble(3)) doReturn 123.456789 + } + @Suppress("DEPRECATION") + val result = resultSet.getBigDecimal("balance", 2) + assertEquals(BigDecimal.valueOf(123.46), result) + } + + @Test + fun getBytesByName() { + val bytes = byteArrayOf(5, 6, 7, 8) + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getBlob(1)) doReturn bytes + } + assertEquals(bytes.toList(), resultSet.getBytes("name")?.toList()) + } + + @Test + fun getDateByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2024-01-15" + } + assertEquals(Date.valueOf("2024-01-15"), resultSet.getDate("name")) + } + + @Test + fun getDateWithCalendar() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2024-01-15" + } + assertEquals(Date.valueOf("2024-01-15"), resultSet.getDate(2, null)) + assertEquals(Date.valueOf("2024-01-15"), resultSet.getDate("name", null)) + } + + @Test + fun getTimeByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "14:30:00" + } + assertEquals(Time.valueOf("14:30:00"), resultSet.getTime("name")) + } + + @Test + fun getTimeWithCalendar() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "14:30:00" + } + assertEquals(Time.valueOf("14:30:00"), resultSet.getTime(2, null)) + assertEquals(Time.valueOf("14:30:00"), resultSet.getTime("name", null)) + } + + @Test + fun getTimestampByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2024-01-15 14:30:00" + } + assertEquals(Timestamp.valueOf("2024-01-15 14:30:00"), resultSet.getTimestamp("name")) + } + + @Test + fun getTimestampWithCalendar() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2024-01-15 14:30:00" + } + assertEquals(Timestamp.valueOf("2024-01-15 14:30:00"), resultSet.getTimestamp(2, null)) + assertEquals(Timestamp.valueOf("2024-01-15 14:30:00"), resultSet.getTimestamp("name", null)) + } + + @Test + fun getDateNull() { + whenever(mockCursor.isNull(1)) doReturn true + assertNull(resultSet.getDate(2)) + assertTrue(resultSet.wasNull()) + } + + @Test + fun getTimeNull() { + whenever(mockCursor.isNull(1)) doReturn true + assertNull(resultSet.getTime(2)) + assertTrue(resultSet.wasNull()) + } + + @Test + fun getTimestampNull() { + whenever(mockCursor.isNull(1)) doReturn true + assertNull(resultSet.getTimestamp(2)) + assertTrue(resultSet.wasNull()) + } + + @Test + fun getAsciiStream() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "test data" + } + resultSet.getAsciiStream(2).use { + assertEquals("test data", it?.readBytes()?.toString(Charsets.US_ASCII)) + } + } + + @Test + fun getAsciiStreamByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "named test" + } + resultSet.getAsciiStream("name").use { + assertEquals("named test", it?.readBytes()?.toString(Charsets.US_ASCII)) + } + } + + @Test + fun getUnicodeStream() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "unicode" + } + @Suppress("DEPRECATION") + resultSet.getUnicodeStream(2).use { + assertNotNull(it) + } + } + + @Test + fun getUnicodeStreamByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "unicode" + } + @Suppress("DEPRECATION") + resultSet.getUnicodeStream("name").use { + assertNotNull(it) + } + } + + @Test + fun getBinaryStream() { + val bytes = byteArrayOf(1, 2, 3) + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getBlob(1)) doReturn bytes + } + resultSet.getBinaryStream(2).use { + assertEquals(bytes.toList(), it?.readBytes()?.toList()) + } + } + + @Test + fun getBinaryStreamByName() { + val bytes = byteArrayOf(4, 5, 6) + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getBlob(1)) doReturn bytes + } + resultSet.getBinaryStream("name").use { + assertEquals(bytes.toList(), it?.readBytes()?.toList()) + } + } + + @Test + fun getCharacterStream() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "character stream" + } + resultSet.getCharacterStream(2).use { + assertEquals("character stream", it?.readText()) + } + } + + @Test + fun getCharacterStreamByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "named stream" + } + resultSet.getCharacterStream("name").use { + assertEquals("named stream", it?.readText()) + } + } + + @Test + fun getNCharacterStream() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "nchar stream" + } + resultSet.getNCharacterStream(2).use { + assertEquals("nchar stream", it?.readText()) + } + } + + @Test + fun getNCharacterStreamByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "nchar named" + } + resultSet.getNCharacterStream("name").use { + assertEquals("nchar named", it?.readText()) + } + } + + @Test + fun getClob() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "clob data" + } + assertEquals("clob data", resultSet.getClob(2)?.getSubString(1, "clob data".length)) + } + + @Test + fun getClobByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "named clob" + } + assertEquals("named clob", resultSet.getClob("name")?.getSubString(1, "named clob".length)) + } + + @Test + fun getClobNull(): Unit = resultSet.run { + whenever(mockCursor.isNull(1)) doReturn true + assertNull(getClob(2)) + assertTrue(wasNull()) + assertNull(getClob("name")) + assertTrue(wasNull()) + } + + @Test + fun getNString() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "nstring" + } + resultSet.run { + assertEquals("nstring", getNString(2)) + assertEquals("nstring", getNString("name")) + } + } + + @Test + fun getObjectByName() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "object by name" + } + assertEquals("object by name", resultSet.getObject("name")) + } + + @Test + fun getObjectWithType() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doReturn 42L + } + resultSet.run { + assertEquals(42, getObject(1, Int::class.java)) + assertEquals(42L, getObject(1, Long::class.java)) + assertEquals(42.0, getObject(1, Double::class.java)) + assertEquals(42.0f, getObject(1, Float::class.java)) + assertEquals("42", getObject(1, String::class.java)) + assertEquals(true, getObject(1, Boolean::class.java)) + } + } + + @Test + fun getObjectWithTypeByName() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doReturn 100L + } + assertEquals(100, resultSet.getObject("id", Int::class.java)) + } + + @Test + fun getObjectWithTypeNull() { + whenever(mockCursor.isNull(0)) doReturn true + assertNull(resultSet.getObject(1, String::class.java)) + } + + @Test + fun getObjectWithTypeBoolean() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "true" + } + assertEquals(true, resultSet.getObject(2, Boolean::class.java)) + whenever(mockCursor.getString(1)) doReturn "1" + assertEquals(true, resultSet.getObject(2, Boolean::class.java)) + whenever(mockCursor.getString(1)) doReturn "false" + assertEquals(false, resultSet.getObject(2, Boolean::class.java)) + } + + @Test + fun getObjectWithTypeLocalDate() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2024-01-15" + } + assertEquals(LocalDate.parse("2024-01-15"), resultSet.getObject(2, LocalDate::class.java)) + } + + @Test + fun getObjectWithTypeLocalTime() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "14:30:00" + } + assertEquals(LocalTime.parse("14:30:00"), resultSet.getObject(2, LocalTime::class.java)) + } + + @Test + fun getObjectWithTypeLocalDateTime() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "2024-01-15T14:30:00" + } + assertEquals(LocalDateTime.parse("2024-01-15T14:30:00"), resultSet.getObject(2, LocalDateTime::class.java)) + } + + @Test + fun getObjectWithTypeInvalid() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "test" + } + assertFailsWith { + resultSet.getObject(2, Blob::class.java) + } + } + + @Test + fun getObjectWithMap() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doReturn 123L + } + assertEquals(123L, resultSet.getObject(1, mutableMapOf())) + assertEquals(123L, resultSet.getObject("id", mutableMapOf())) + } + + @Test + fun getIntTypeConversions() { + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "456" + } + assertEquals(456, resultSet.getInt(2)) + mockCursor.run { + whenever(isNull(3)) doReturn false + whenever(getDouble(3)) doReturn 789.99 + } + assertEquals(789, resultSet.getInt(4)) + whenever(mockCursor.isNull(0)) doReturn false + assertEquals(0, resultSet.getInt(1)) + } + + @Test + fun getDoubleTypeConversions() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getInt(0)) doReturn 100 + } + assertEquals(100.0, resultSet.getDouble(1), 0.001) + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "200.5" + } + assertEquals(200.5, resultSet.getDouble(2), 0.001) + whenever(mockCursor.isNull(3)) doReturn false + assertEquals(0.0, resultSet.getDouble(4), 0.001) + } + + @Test + fun scrollableResultSet() { + val scrollable = JdbcResultSet( + mockCursor, + mockStatement, + ResultSet.TYPE_SCROLL_INSENSITIVE, + ResultSet.CONCUR_READ_ONLY + ) + whenever(mockCursor.moveToFirst()) doReturn true + assertTrue(scrollable.first()) + whenever(mockCursor.moveToLast()) doReturn true + assertTrue(scrollable.last()) + whenever(mockCursor.moveToPrevious()) doReturn true + assertTrue(scrollable.previous()) + whenever(mockCursor.moveToPosition(4)) doReturn true + assertTrue(scrollable.absolute(5)) + whenever(mockCursor.move(3)) doReturn true + assertTrue(scrollable.relative(3)) + whenever(mockCursor.count) doReturn 10 + scrollable.afterLast() + scrollable.beforeFirst() + } + + @Test + fun rowStatus(): Unit = resultSet.run { + assertFalse(rowUpdated()) + assertFalse(rowInserted()) + assertFalse(rowDeleted()) + } + + @Test + fun unsupportedFeatures(): Unit = resultSet.run { + assertFailsWith { getNClob(1) } + assertFailsWith { getNClob("name") } + assertFailsWith { getSQLXML(1) } + assertFailsWith { getSQLXML("name") } + assertFailsWith { getURL(1) } + assertFailsWith { getURL("name") } + assertFailsWith { getArray(1) } + assertFailsWith { getArray("name") } + assertFailsWith { getBlob(1) } + assertFailsWith { getBlob("name") } + assertFailsWith { getRef(1) } + assertFailsWith { getRef("name") } + assertFailsWith { getRowId(1) } + assertFailsWith { getRowId("name") } + } + + @Test + fun allUpdateOperations(): Unit = resultSet.run { + assertFailsWith { updateNull(1) } + assertFailsWith { updateNull("name") } + assertFailsWith { updateBoolean(1, true) } + assertFailsWith { updateBoolean("name", true) } + assertFailsWith { updateByte(1, 1) } + assertFailsWith { updateByte("name", 1) } + assertFailsWith { updateShort(1, 1) } + assertFailsWith { updateShort("name", 1) } + assertFailsWith { updateInt(1, 1) } + assertFailsWith { updateInt("name", 1) } + assertFailsWith { updateLong(1, 1L) } + assertFailsWith { updateLong("name", 1L) } + assertFailsWith { updateFloat(1, 1.0f) } + assertFailsWith { updateFloat("name", 1.0f) } + assertFailsWith { updateDouble(1, 1.0) } + assertFailsWith { updateDouble("name", 1.0) } + assertFailsWith { updateBigDecimal(1, BigDecimal.ONE) } + assertFailsWith { updateBigDecimal("name", BigDecimal.ONE) } + assertFailsWith { updateBytes(1, byteArrayOf()) } + assertFailsWith { updateBytes("name", byteArrayOf()) } + assertFailsWith { updateDate(1, Date(0)) } + assertFailsWith { updateDate("name", Date(0)) } + assertFailsWith { updateTime(1, Time(0)) } + assertFailsWith { updateTime("name", Time(0)) } + assertFailsWith { updateTimestamp(1, Timestamp(0)) } + assertFailsWith { updateTimestamp("name", Timestamp(0)) } + assertFailsWith { updateRow() } + assertFailsWith { refreshRow() } + assertFailsWith { cancelRowUpdates() } + assertFailsWith { moveToInsertRow() } + assertFailsWith { moveToCurrentRow() } + } + + @Test + fun streamUpdateOperations(): Unit = resultSet.run { + val stream = mock() + val reader = mock() + assertFailsWith { updateAsciiStream(1, stream, 10) } + assertFailsWith { updateAsciiStream("name", stream, 10) } + assertFailsWith { updateAsciiStream(1, stream, 10L) } + assertFailsWith { updateAsciiStream("name", stream, 10L) } + assertFailsWith { updateAsciiStream(1, stream) } + assertFailsWith { updateAsciiStream("name", stream) } + assertFailsWith { updateBinaryStream(1, stream, 10) } + assertFailsWith { updateBinaryStream("name", stream, 10) } + assertFailsWith { updateBinaryStream(1, stream, 10L) } + assertFailsWith { updateBinaryStream("name", stream, 10L) } + assertFailsWith { updateBinaryStream(1, stream) } + assertFailsWith { updateBinaryStream("name", stream) } + assertFailsWith { updateCharacterStream(1, reader, 10) } + assertFailsWith { updateCharacterStream("name", reader, 10) } + assertFailsWith { updateCharacterStream(1, reader, 10L) } + assertFailsWith { updateCharacterStream("name", reader, 10L) } + assertFailsWith { updateCharacterStream(1, reader) } + assertFailsWith { updateCharacterStream("name", reader) } + assertFailsWith { updateNCharacterStream(1, reader, 10L) } + assertFailsWith { updateNCharacterStream("name", reader, 10L) } + assertFailsWith { updateNCharacterStream(1, reader) } + assertFailsWith { updateNCharacterStream("name", reader) } + } + + @Test + fun lobUpdateOperations(): Unit = resultSet.run { + val ref = mock() + val blob = mock() + val clob = mock() + val nclob = mock() + val array = mock() + val rowId = mock() + val sqlxml = mock() + val stream = mock() + val reader = mock() + + assertFailsWith { updateRef(1, ref) } + assertFailsWith { updateRef("name", ref) } + assertFailsWith { updateBlob(1, blob) } + assertFailsWith { updateBlob("name", blob) } + assertFailsWith { updateBlob(1, stream, 10L) } + assertFailsWith { updateBlob("name", stream, 10L) } + assertFailsWith { updateBlob(1, stream) } + assertFailsWith { updateBlob("name", stream) } + assertFailsWith { updateClob(1, clob) } + assertFailsWith { updateClob("name", clob) } + assertFailsWith { updateClob(1, reader, 10L) } + assertFailsWith { updateClob("name", reader, 10L) } + assertFailsWith { updateClob(1, reader) } + assertFailsWith { updateClob("name", reader) } + assertFailsWith { updateArray(1, array) } + assertFailsWith { updateArray("name", array) } + assertFailsWith { updateRowId(1, rowId) } + assertFailsWith { updateRowId("name", rowId) } + assertFailsWith { updateNString(1, "test") } + assertFailsWith { updateNString("name", "test") } + assertFailsWith { updateNClob(1, nclob) } + assertFailsWith { updateNClob("name", nclob) } + assertFailsWith { updateNClob(1, reader, 10L) } + assertFailsWith { updateNClob("name", reader, 10L) } + assertFailsWith { updateNClob(1, reader) } + assertFailsWith { updateNClob("name", reader) } + assertFailsWith { updateSQLXML(1, sqlxml) } + assertFailsWith { updateSQLXML("name", sqlxml) } + assertFailsWith { updateObject(1, "test", 1) } + assertFailsWith { updateObject(1, "test") } + assertFailsWith { updateObject("name", "test", 1) } + assertFailsWith { updateObject("name", "test") } + } + + @Test + fun unwrapCursor(): Unit = resultSet.run { + assertTrue(isWrapperFor(ICursor::class.java)) + assertEquals(mockCursor, unwrap(ICursor::class.java)) + } + + @Test + fun getBooleanWithString() { + mockCursor.apply { + whenever(isNull(0)) doReturn false + whenever(getLong(0)) doThrow RuntimeException("Not a long") + whenever(getString(0)) doReturn "true" + } + assertTrue(resultSet.getBoolean(1)) + whenever(mockCursor.getString(0)) doReturn "false" + assertFalse(resultSet.getBoolean(1)) + whenever(mockCursor.getString(0)) doReturn null + assertFalse(resultSet.getBoolean(1)) + } + + @Test + fun getRowReturnsPosition() { + mockCursor.apply { + whenever(isBeforeFirst()) doReturn false + whenever(isAfterLast()) doReturn false + whenever(position()) doReturn 0 + } + assertEquals(1, resultSet.row) + } + + @Test + fun getRowWhenBeforeFirst() { + mockCursor.apply { + whenever(isBeforeFirst()) doReturn true + whenever(isAfterLast()) doReturn false + } + assertEquals(0, resultSet.row) + } + + @Test + fun getRowWhenAfterLast() { + mockCursor.apply { + whenever(isBeforeFirst()) doReturn false + whenever(isAfterLast()) doReturn true + } + assertEquals(0, resultSet.row) + } + + @Test + fun isLastOnForwardOnlyResultSet() { + assertFalse(resultSet.isLast) + } + + @Test + fun rowDeletedInsertedUpdated(): Unit = resultSet.run { + assertFalse(rowDeleted()) + assertFalse(rowInserted()) + assertFalse(rowUpdated()) + } + + @Test + fun getCursorName() { + assertFailsWith { + resultSet.cursorName + } + } + + @Test + fun getStatementWhenNull() { + assertNull(JdbcResultSet(mockCursor, null).statement) + } + + @Test + fun getHoldability() { + assertEquals(ResultSet.CLOSE_CURSORS_AT_COMMIT, resultSet.holdability) + } + + @Test + fun getWarnings(): Unit = resultSet.run { + assertNull(warnings) + clearWarnings() + assertNull(warnings) + } + + @Test + fun unwrapInvalidClass() { + assertFailsWith { + resultSet.unwrap(String::class.java) + } + } + + @Test + fun isWrapperForInvalidClass() { + assertFalse(resultSet.isWrapperFor(String::class.java)) + } + + @Test + fun findColumnInvalid() { + whenever(mockCursor.columnNames()) doReturn arrayOf("id", "name") + assertFailsWith { + resultSet.findColumn("invalid_column") + } + } + + @Test + fun getObjectWithNullType() { + whenever(mockCursor.isNull(0)) doReturn true + resultSet.run { + assertNull(getObject(1, String::class.java)) + assertTrue(wasNull()) + } + } + + @Test + fun getURLThrows(): Unit = resultSet.run { + assertFailsWith { + getURL(1) + } + assertFailsWith { + getURL("url") + } + } + + @Test + fun getRefThrows(): Unit = resultSet.run { + assertFailsWith { + getRef(1) + } + assertFailsWith { + getRef("name") + } + } + + @Test + fun getBlobThrows(): Unit = resultSet.run { + assertFailsWith { + getBlob(1) + } + assertFailsWith { + getBlob("name") + } + } + + @Test + fun getArrayThrows(): Unit = resultSet.run { + assertFailsWith { + getArray(1) + } + assertFailsWith { + getArray("name") + } + } + + @Test + fun getRowIdThrows(): Unit = resultSet.run { + assertFailsWith { + getRowId(1) + } + assertFailsWith { + getRowId("name") + } + } + + @Test + fun getNStringDelegatesToGetString() { + mockCursor.apply { + whenever(isNull(0)) doReturn false + whenever(getString(0)) doReturn "test" + } + assertEquals("test", resultSet.getNString(1)) + mockCursor.apply { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "test" + } + assertEquals("test", resultSet.getNString("name")) + } + + @Test + fun getNCharacterStreamDelegatesToGetCharacterStream() { + mockCursor.run { + whenever(isNull(0)) doReturn false + whenever(getString(0)) doReturn "test data" + } + assertNotNull(resultSet.getNCharacterStream(1)) + mockCursor.run { + whenever(isNull(1)) doReturn false + whenever(getString(1)) doReturn "test data" + } + assertNotNull(resultSet.getNCharacterStream("name")) + } + + @Test + fun getNClobThrows(): Unit = resultSet.run { + assertFailsWith { + getNClob(1) + } + assertFailsWith { + getNClob("name") + } + } + + @Test + fun getSQLXMLThrows(): Unit = resultSet.run { + assertFailsWith { + getSQLXML(1) + } + assertFailsWith { + getSQLXML("name") + } + } + + @Test + fun updateNCharacterStreamThrows(): Unit = resultSet.run { + assertFailsWith { + updateNCharacterStream(1, "data".reader(), 4L) + } + assertFailsWith { + updateNCharacterStream("name", "data".reader(), 4L) + } + assertFailsWith { + updateNCharacterStream(1, "data".reader()) + } + assertFailsWith { + updateNCharacterStream("name", "data".reader()) + } + } + + @Test + fun getObjectWithInvalidDateFormat() { + mockCursor.apply { + whenever(isNull(0)) doReturn false + whenever(getString(0)) doReturn "invalid-date" + } + resultSet.run { + assertFailsWith { + getObject(1, LocalDate::class.java) + } + assertFailsWith { + getObject(1, LocalTime::class.java) + } + assertFailsWith { + getObject(1, LocalDateTime::class.java) + } + } + } + + @Test + fun getMetaDataReturnsCorrectInstance() { + resultSet.metaData.let { + assertNotNull(it) + assertSame(it, resultSet.metaData) + } + } + + @Test + fun setFetchDirection(): Unit = resultSet.run { + fetchDirection = ResultSet.FETCH_FORWARD + assertEquals(ResultSet.FETCH_FORWARD, fetchDirection) + assertFailsWith { + fetchDirection = ResultSet.FETCH_REVERSE + } + } + + @Test + fun setFetchSize(): Unit = resultSet.run { + fetchSize = 100 + assertEquals(100, fetchSize) + assertFailsWith { + fetchSize = -1 + } + } + + @Test + fun typeAndConcurrency(): Unit = resultSet.run { + assertEquals(ResultSet.TYPE_FORWARD_ONLY, type) + assertEquals(ResultSet.CONCUR_READ_ONLY, concurrency) + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcParameterMetaDataTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcParameterMetaDataTest.kt new file mode 100644 index 0000000000..8df949684b --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcParameterMetaDataTest.kt @@ -0,0 +1,220 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.statement + +import org.junit.jupiter.api.Test +import java.sql.ParameterMetaData +import java.sql.SQLException +import java.sql.Types +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertSame +import kotlin.test.assertTrue + +internal class JdbcParameterMetaDataTest { + @Test + fun getParameterCount() { + val metaData = JdbcParameterMetaData(5) + assertEquals(5, metaData.parameterCount) + } + + @Test + fun getParameterCountZero() { + val metaData = JdbcParameterMetaData(0) + assertEquals(0, metaData.parameterCount) + } + + @Test + fun isNullable(): Unit = JdbcParameterMetaData(3).run { + assertEquals(ParameterMetaData.parameterNullableUnknown, isNullable(1)) + assertEquals(ParameterMetaData.parameterNullableUnknown, isNullable(2)) + assertEquals(ParameterMetaData.parameterNullableUnknown, isNullable(3)) + } + + @Test + fun isNullableInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + isNullable(0) + } + assertFailsWith { + isNullable(3) + } + assertFailsWith { + isNullable(-1) + } + } + + @Test + fun isSigned(): Unit = JdbcParameterMetaData(3).run { + assertFalse(isSigned(1)) + assertFalse(isSigned(2)) + assertFalse(isSigned(3)) + } + + @Test + fun isSignedInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + isSigned(0) + } + assertFailsWith { + isSigned(3) + } + } + + @Test + fun getPrecision(): Unit = JdbcParameterMetaData(3).run { + assertEquals(0, getPrecision(1)) + assertEquals(0, getPrecision(2)) + assertEquals(0, getPrecision(3)) + } + + @Test + fun getPrecisionInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + getPrecision(0) + } + assertFailsWith { + getPrecision(3) + } + } + + @Test + fun getScale(): Unit = JdbcParameterMetaData(3).run { + assertEquals(0, getScale(1)) + assertEquals(0, getScale(2)) + assertEquals(0, getScale(3)) + } + + @Test + fun getScaleInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + getScale(0) + } + assertFailsWith { + getScale(3) + } + } + + @Test + fun getParameterType(): Unit = JdbcParameterMetaData(3).run { + assertEquals(Types.VARCHAR, getParameterType(1)) + assertEquals(Types.VARCHAR, getParameterType(2)) + assertEquals(Types.VARCHAR, getParameterType(3)) + } + + @Test + fun getParameterTypeInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + getParameterType(0) + } + assertFailsWith { + getParameterType(3) + } + } + + @Test + fun getParameterTypeName(): Unit = JdbcParameterMetaData(3).run { + assertEquals("VARCHAR", getParameterTypeName(1)) + assertEquals("VARCHAR", getParameterTypeName(2)) + assertEquals("VARCHAR", getParameterTypeName(3)) + } + + @Test + fun getParameterTypeNameInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + getParameterTypeName(0) + } + assertFailsWith { + getParameterTypeName(3) + } + } + + @Test + fun getParameterClassName(): Unit = JdbcParameterMetaData(3).run { + assertEquals(String::class.java.name, getParameterClassName(1)) + assertEquals(String::class.java.name, getParameterClassName(2)) + assertEquals(String::class.java.name, getParameterClassName(3)) + } + + @Test + fun getParameterClassNameInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + getParameterClassName(0) + } + assertFailsWith { + getParameterClassName(3) + } + } + + @Test + fun getParameterMode(): Unit = JdbcParameterMetaData(3).run { + assertEquals(ParameterMetaData.parameterModeIn, getParameterMode(1)) + assertEquals(ParameterMetaData.parameterModeIn, getParameterMode(2)) + assertEquals(ParameterMetaData.parameterModeIn, getParameterMode(3)) + } + + @Test + fun getParameterModeInvalidIndex(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + getParameterMode(0) + } + assertFailsWith { + getParameterMode(3) + } + } + + @Test + fun unwrap(): Unit = JdbcParameterMetaData(2).run { + assertSame(this, unwrap(JdbcParameterMetaData::class.java)) + } + + @Test + fun unwrapToParameterMetaData(): Unit = JdbcParameterMetaData(2).run { + assertSame(this, unwrap(ParameterMetaData::class.java)) + } + + @Test + fun unwrapInvalidClass(): Unit = JdbcParameterMetaData(2).run { + assertFailsWith { + unwrap(String::class.java) + } + } + + @Test + fun isWrapperFor(): Unit = JdbcParameterMetaData(2).run { + assertTrue(isWrapperFor(JdbcParameterMetaData::class.java)) + assertTrue(isWrapperFor(ParameterMetaData::class.java)) + } + + @Test + fun isWrapperForInvalidClass() { + assertFalse(JdbcParameterMetaData(2).isWrapperFor(String::class.java)) + } + + @Test + fun validateParameterIndexBoundaries(): Unit = JdbcParameterMetaData(5).run { + getParameterType(1) + getParameterType(5) + assertFailsWith { + getParameterType(0) + } + assertFailsWith { + getParameterType(6) + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcPreparedStatementTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcPreparedStatementTest.kt new file mode 100644 index 0000000000..4c9b422ae3 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcPreparedStatementTest.kt @@ -0,0 +1,939 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.statement + +import com.bloomberg.selekt.ICursor +import com.bloomberg.selekt.ISQLStatement +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.result.JdbcResultSet +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.io.InputStream +import java.io.Reader +import java.math.BigDecimal +import java.net.URI +import java.sql.Blob +import java.sql.Clob +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.doAnswer +import org.mockito.kotlin.doThrow +import org.mockito.kotlin.mock +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever +import java.sql.Date +import java.sql.NClob +import java.sql.PreparedStatement +import java.sql.Ref +import java.sql.ResultSet +import java.sql.RowId +import java.sql.SQLException +import java.sql.SQLXML +import java.sql.Time +import java.sql.Timestamp +import java.sql.Types +import java.util.Properties +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import org.junit.jupiter.api.AfterEach +import org.mockito.kotlin.doReturn + +internal class JdbcPreparedStatementTest { + private lateinit var database: SQLDatabase + private lateinit var connection: JdbcConnection + private lateinit var cursor: ICursor + private lateinit var preparedStatement: JdbcPreparedStatement + + @BeforeEach + fun setUp() { + database = mock() + cursor = mock() + val connectionURL = ConnectionURL.parse("jdbc:selekt:/tmp/test.db") + val properties = Properties() + connection = JdbcConnection(database, connectionURL, properties) + val sql = "SELECT * FROM users WHERE id = ? AND name = ?" + preparedStatement = JdbcPreparedStatement(connection, database, sql) + } + + @AfterEach + fun tearDown() { + preparedStatement.close() + connection.close() + } + + @Test + fun executeQuery() { + whenever(database.query(any(), any>())) doReturn cursor + assertTrue(preparedStatement.apply { + setInt(1, 42) + setString(2, "test") + }.executeQuery() is JdbcResultSet) + verify(database).query(any(), any>()) + } + + @Test + fun executeUpdate() { + val mockStatement = mock() + whenever(database.compileStatement(any(), any>())) doReturn mockStatement + whenever(mockStatement.executeUpdateDelete()) doReturn 2 + assertEquals(2, preparedStatement.apply { + setInt(1, 42) + setString(2, "updated") + }.executeUpdate()) + verify(database).compileStatement(any(), any>()) + verify(mockStatement).executeUpdateDelete() + } + + @Test + fun execute() { + whenever(database.query(any(), any>())) doReturn cursor + assertTrue(preparedStatement.apply { + setInt(1, 42) + setString(2, "test") + }.execute()) + verify(database).query(any(), any>()) + } + + @Test + fun executeWithUpdateStatement() { + val mockStatement = mock() + whenever(database.compileStatement(any(), any>())) doReturn mockStatement + whenever(mockStatement.executeUpdateDelete()) doReturn 1 + assertFalse(JdbcPreparedStatement( + connection, + database, + "UPDATE users SET name = ? WHERE id = ?" + ).apply { + setString(1, "updated") + setInt(2, 42) + }.use(PreparedStatement::execute)) + verify(database).compileStatement(any(), any>()) + verify(mockStatement).executeUpdateDelete() + } + + @Test + fun executeQueryWithException() { + whenever(database.query(any(), any>())) doThrow RuntimeException("Query failed") + assertFailsWith { + preparedStatement.apply { + setInt(1, 42) + setString(2, "test") + }.executeQuery() + } + } + + @Test + fun executeUpdateWithException() { + val mockStatement = mock() + whenever(database.compileStatement(any(), any>())) doReturn mockStatement + whenever(mockStatement.executeUpdateDelete()) doThrow RuntimeException("Update failed") + assertFailsWith { + preparedStatement.apply { + setInt(1, 42) + setString(2, "test") + }.executeUpdate() + } + } + + @Test + fun executeWithException() { + whenever(database.query(any(), any>())) doThrow RuntimeException("Execute failed") + assertFailsWith { + preparedStatement.apply { + setInt(1, 42) + setString(2, "test") + }.execute() + } + } + + @Test + fun parameterBinding() { + whenever(database.query(any(), any>())) doReturn cursor + preparedStatement.run { + setNull(1, Types.VARCHAR) + setBoolean(2, true) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun dateTimeParameters() { + val sql = "SELECT * FROM events WHERE event_date = ? AND event_time = ? AND event_timestamp = ?" + val statement = JdbcPreparedStatement(connection, database, sql) + whenever(database.query(any(), any>())) doReturn cursor + statement.apply { + setDate(1, Date.valueOf("2026-02-04")) + setTime(2, Time.valueOf("10:30:45")) + setTimestamp(3, Timestamp.valueOf("2026-02-04 10:30:45")) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setObject() { + whenever(database.query(any(), any>())) doReturn cursor + preparedStatement.apply { + setObject(1, 42) + setObject(2, "string") + }.executeQuery() + verify(database).query(any(), any>()) + } + + @Test + fun setObjectWithTargetType() { + whenever(database.query(any(), any>())) doReturn cursor + preparedStatement.apply { + setObject(1, "42", Types.INTEGER) + setObject(2, 42, Types.VARCHAR) + }.executeQuery() + verify(database).query(any(), any>()) + } + + @Test + fun clearParameters(): Unit = preparedStatement.run { + setInt(1, 42) + setString(2, "test") + clearParameters() + whenever(database.query(any(), any>())) doReturn cursor + setInt(1, 100) + executeQuery() + verify(database).query(any(), any>()) + } + + @Test + fun preparedStatementReuse() { + whenever(database.query(any(), any>())) doReturn cursor + preparedStatement.apply { + setInt(1, 42) + setString(2, "first") + }.executeQuery() + preparedStatement.apply { + setInt(1, 100) + setString(2, "second") + }.executeQuery() + preparedStatement.apply { + clearParameters() + setInt(1, 200) + setString(2, "third") + }.executeQuery() + verify(database, times(3)).query(any(), any>()) + } + + @Test + fun getParameterMetaData() { + assertNotNull(preparedStatement.parameterMetaData) + } + + @Test + fun getMetaData() { + assertFailsWith { + preparedStatement.metaData + } + } + + @Test + fun executeBatch() { + val batchSql = "UPDATE users SET name = ? WHERE id = ?" + val batchStatement = JdbcPreparedStatement(connection, database, batchSql).apply { + setString(1, "first") + setInt(2, 1) + addBatch() + setString(1, "second") + setInt(2, 2) + addBatch() + } + whenever(database.batch(any(), any>>())) doAnswer { invocation -> + invocation.getArgument>>(1).count() + } + batchStatement.executeBatch().run { + assertEquals(2, size) + assertEquals(-2, first()) + assertEquals(-2, this[1]) + } + verify(database).batch(any(), any>>()) + } + + @Test + fun executeBatchWithChunkExpansion() { + val batchSql = "UPDATE users SET value = ? WHERE id = ?" + val batchStatement = JdbcPreparedStatement(connection, database, batchSql).apply { + repeat(150) { i -> + setInt(1, i) + setInt(2, i) + addBatch() + } + } + whenever(database.batch(any(), any>>())) doAnswer { invocation -> + invocation.getArgument>>(1).count() + } + val updateCounts = batchStatement.executeBatch() + assertEquals(150, updateCounts.size) + verify(database).batch(any(), any>>()) + } + + @Test + fun invalidParameterIndex() { + assertFailsWith { + preparedStatement.setInt(0, 42) + } + assertFailsWith { + preparedStatement.setString(-1, "test") + } + } + + @Test + fun executeWithoutParameters() { + whenever(database.query(any(), any>())) doReturn cursor + assertTrue(JdbcPreparedStatement(connection, database, "SELECT COUNT(*) FROM users").execute()) + } + + @Test + fun unsupportedMethods(): Unit = preparedStatement.run { + assertFailsWith { + setRef(1, mock()) + } + assertFailsWith { + setBlob(1, mock()) + } + assertFailsWith { + setArray(1, mock()) + } + assertFailsWith { + setURL(1, URI.create("http://example.com").toURL()) + } + assertFailsWith { + setRowId(1, mock()) + } + assertFailsWith { + setNString(1, "test") + } + assertFailsWith { + setNCharacterStream(1, mock()) + } + assertFailsWith { + setNClob(1, mock()) + } + assertFailsWith { + setSQLXML(1, mock()) + } + } + + @Test + fun closure(): Unit = preparedStatement.run { + assertFalse(isClosed) + close() + assertTrue(isClosed) + } + + @Test + fun operationsAfterClose(): Unit = preparedStatement.run { + close() + assertFailsWith { + executeQuery() + } + assertFailsWith { + setInt(1, 42) + } + } + + @Test + fun wrapperInterface(): Unit = preparedStatement.run { + assertTrue(isWrapperFor(JdbcPreparedStatement::class.java)) + assertFalse(isWrapperFor(String::class.java)) + assertEquals(preparedStatement, unwrap(JdbcPreparedStatement::class.java)) + assertFailsWith { + preparedStatement.unwrap(String::class.java) + } + } + + @Test + fun getSql() { + assertEquals("SELECT * FROM users WHERE id = ? AND name = ?", preparedStatement.sql) + } + + @Test + fun parameterIndexValidation() { + whenever(database.query(any(), any>())) doReturn cursor + preparedStatement.apply { + setInt(1, 42) + setString(2, "test") + }.executeQuery() + verify(database).query(any(), any>()) + } + + @Test + fun repeatExecute() { + val mockStatement = mock { + whenever(it.executeUpdateDelete()) doReturn 1 + } + whenever(database.compileStatement(any(), any>())) doReturn mockStatement + val resultOne = preparedStatement.apply { + setInt(1, 1) + setString(2, "first") + }.executeUpdate() + val resultTwo = preparedStatement.apply { + setInt(1, 2) + setString(2, "second") + }.executeUpdate() + assertEquals(1, resultOne) + assertEquals(1, resultTwo) + verify(database, times(2)).compileStatement(any(), any>()) + } + + @Test + fun setAllBasicTypes() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE a=? AND b=? AND c=? AND d=? AND e=? AND f=?" + ).run { + setBoolean(1, true) + setByte(2, 42.toByte()) + setShort(3, 100.toShort()) + setLong(4, 1000L) + setFloat(5, 1.5f) + setDouble(6, 2.5) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setBinaryData() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE data=?" + ).run { + setBytes(1, byteArrayOf(1, 2, 3, 4)) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setStreams() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE a=? AND b=?" + ).run { + "test".byteInputStream().use { asciiStream -> + byteArrayOf(1, 2, 3).inputStream().use { binaryStream -> + setAsciiStream(1, asciiStream) + setBinaryStream(2, binaryStream) + } + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setCharacterStream() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE text=?" + ).run { + "test data".reader().use { + setCharacterStream(1, it) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setClob() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE clob_data=?" + ).run { + setClob(1, mock { + whenever(it.getSubString(any(), any())) doReturn "clob data" + whenever(it.length()) doReturn 9L + }) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setObjectWithScaleAndTargetType() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE value=?" + ).run { + setObject(1, 123.456, Types.DECIMAL, 2) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun addBatchWithoutParameters() { + whenever(database.batch(any(), any>>())) doReturn 2 + val results = JdbcPreparedStatement( + connection, + database, + "UPDATE test SET value=? WHERE id=?" + ).run { + setInt(1, 100) + setInt(2, 1) + addBatch() + setInt(1, 200) + setInt(2, 2) + addBatch() + executeBatch() + } + assertEquals(2, results.size) + verify(database).batch(any(), any>>()) + } + + @Test + fun clearBatch() { + whenever(database.batch(any(), any>>())) doReturn 0 + val results = JdbcPreparedStatement( + connection, + database, + "UPDATE test SET value=? WHERE id=?" + ).run { + setInt(1, 100) + setInt(2, 1) + addBatch() + clearBatch() + executeBatch() + } + assertEquals(0, results.size) + } + + @Test + fun setNullWithTypeName() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE value=?" + ).run { + setNull(1, Types.VARCHAR, "VARCHAR") + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun executeWithUnsetParameters() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE a=? AND b=?" + ).run { + setInt(1, 42) + execute() + } + verify(database).query(any(), any>()) + } + + @Test + fun setAsciiStreamWithLength() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE text=?").run { + "test".byteInputStream().use { + setAsciiStream(1, it, 4) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setAsciiStreamWithLongLength() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE text=?").run { + "test".byteInputStream().use { + setAsciiStream(1, it, 4L) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setBinaryStreamWithLength() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE data=?").run { + byteArrayOf(1, 2, 3, 4).inputStream().use { + setBinaryStream(1, it, 4) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setBinaryStreamWithLongLength() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE data=?").run { + byteArrayOf(1, 2, 3, 4).inputStream().use { + setBinaryStream(1, it, 4L) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setUnicodeStream() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE text=?").run { + "test".byteInputStream().use { + @Suppress("DEPRECATION") + setUnicodeStream(1, it, 4) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setCharacterStreamWithLength() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE text=?").run { + "test data".reader().use { + setCharacterStream(1, it, 9) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setCharacterStreamWithLongLength() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE text=?").run { + "test data".reader().use { + setCharacterStream(1, it, 9L) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setDateWithCalendar() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE date=?").run { + setDate(1, Date.valueOf("2026-02-04"), null) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setTimeWithCalendar() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE time=?").run { + setTime(1, Time.valueOf("10:30:45"), null) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setTimestampWithCalendar() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE ts=?").run { + setTimestamp(1, Timestamp.valueOf("2026-02-04 10:30:45"), null) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setClobWithReader() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE clob=?").run { + "clob content".reader().use { + setClob(1, it) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setClobWithReaderAndLength() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE clob=?").run { + "clob content".reader().use { + setClob(1, it, 12L) + } + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setClobNull() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE clob=?").run { + setClob(1, null as Clob?) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setClobReaderNull() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE clob=?").run { + setClob(1, null as Reader?) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setClobReaderWithLengthNull() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE clob=?").run { + setClob(1, null as Reader?, 0L) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setBigDecimal() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE amount=?").run { + setBigDecimal(1, BigDecimal("123.45")) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setBigDecimalNull() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement(connection, database, "SELECT * FROM test WHERE amount=?").run { + setBigDecimal(1, null) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setNullValues() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE a=? AND b=? AND c=? AND d=? AND e=?" + ).run { + setString(1, null) + setBytes(2, null) + setDate(3, null) + setTime(4, null) + setTimestamp(5, null) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun setNullStreams() { + whenever(database.query(any(), any>())) doReturn cursor + JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE a=? AND b=? AND c=?" + ).run { + setAsciiStream(1, null as InputStream?) + setBinaryStream(2, null as InputStream?) + setCharacterStream(3, null as Reader?) + executeQuery() + } + verify(database).query(any(), any>()) + } + + @Test + fun executeBatchWithSelectStatement(): Unit = JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE id=?" + ).run { + setInt(1, 1) + addBatch() + assertFailsWith { + executeBatch() + } + } + + @Test + fun executeBatchWithException() { + whenever(database.batch(any(), any>>())) doThrow SQLException("Batch failed") + JdbcPreparedStatement(connection, database, "UPDATE test SET value=? WHERE id=?").run { + setInt(1, 100) + setInt(2, 1) + addBatch() + assertFailsWith { + executeBatch() + } + } + } + + @Test + fun largeBatch() { + whenever(database.batch(any(), any>>())) doReturn 50 + val results = JdbcPreparedStatement( + connection, + database, + "UPDATE test SET value=? WHERE id=?", + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.CLOSE_CURSORS_AT_COMMIT + ).run { + for (i in 1..50) { + setInt(1, i * 10) + setInt(2, i) + addBatch() + } + executeBatch() + } + assertEquals(50, results.size) + verify(database).batch(any(), any>>()) + } + + @Test + fun unsupportedBlobMethods(): Unit = JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE blob=?" + ).run { + assertFailsWith { + setBlob(1, "data".byteInputStream(), 4L) + } + assertFailsWith { + setBlob(1, "data".byteInputStream()) + } + } + + @Test + fun unsupportedNCharacterStreamWithLength(): Unit = JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE text=?" + ).run { + assertFailsWith { + setNCharacterStream(1, "data".reader(), 4L) + } + } + + @Test + fun unsupportedNClobMethods(): Unit = JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE clob=?" + ).run { + assertFailsWith { + setNClob(1, "data".reader(), 4L) + } + assertFailsWith { + setNClob(1, "data".reader()) + } + } + + @Test + fun parameterIndexOutOfRange(): Unit = JdbcPreparedStatement( + connection, + database, + "SELECT * FROM test WHERE id=?" + ).run { + assertFailsWith { + setInt(3, 42) + } + assertFailsWith { + setString(0, "test") + } + } + + @Test + fun clearParametersAfterClose(): Unit = preparedStatement.run { + close() + assertFailsWith { + clearParameters() + } + } + + @Test + fun addBatchAfterClose(): Unit = preparedStatement.run { + close() + assertFailsWith { + addBatch() + } + } + + @Test + fun clearBatchAfterClose(): Unit = preparedStatement.run { + close() + assertFailsWith { + clearBatch() + } + } + + @Test + fun executeBatchAfterClose(): Unit = preparedStatement.run { + close() + assertFailsWith { + executeBatch() + } + } + + @Test + fun executeUpdateAfterClose(): Unit = preparedStatement.run { + close() + assertFailsWith { + executeUpdate() + } + } + + @Test + fun executeAfterClose(): Unit = preparedStatement.run { + close() + assertFailsWith { + execute() + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcStatementTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcStatementTest.kt new file mode 100644 index 0000000000..f8803e0fe8 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/statement/JdbcStatementTest.kt @@ -0,0 +1,530 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.statement + +import com.bloomberg.selekt.ICursor +import com.bloomberg.selekt.ISQLStatement +import com.bloomberg.selekt.SQLDatabase +import com.bloomberg.selekt.jdbc.connection.JdbcConnection +import com.bloomberg.selekt.jdbc.result.JdbcResultSet +import com.bloomberg.selekt.jdbc.util.ConnectionURL +import java.sql.ResultSet +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.doThrow +import org.mockito.kotlin.isNull +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever +import java.sql.SQLException +import java.sql.Statement +import java.util.Properties +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertSame +import kotlin.test.assertTrue +import org.mockito.kotlin.doReturn + +internal class JdbcStatementTest { + private lateinit var mockDatabase: SQLDatabase + private lateinit var mockConnection: JdbcConnection + private lateinit var mockCursor: ICursor + private lateinit var statement: JdbcStatement + + @BeforeEach + fun setUp() { + mockDatabase = mock() + mockCursor = mock() + + val connectionURL = ConnectionURL.parse("jdbc:selekt:/tmp/test.db") + val properties = Properties() + mockConnection = JdbcConnection(mockDatabase, connectionURL, properties) + + statement = JdbcStatement(mockConnection, mockDatabase) + } + + @Test + fun executeQuery() { + val sql = "SELECT * FROM users" + whenever(mockDatabase.query(sql, emptyArray())) doReturn mockCursor + + val resultSet = statement.executeQuery(sql) + assertNotNull(resultSet) + assertTrue(resultSet is JdbcResultSet) + assertEquals(resultSet, statement.resultSet) + } + + @Test + fun executeUpdate() { + val sql = "INSERT INTO users (name) VALUES ('test')" + val mockStatement = mock() + whenever(mockDatabase.compileStatement(sql, null)) doReturn mockStatement + whenever(mockStatement.executeUpdateDelete()) doReturn 1 + + val updateCount = statement.executeUpdate(sql) + assertEquals(1, updateCount) + assertEquals(1, statement.updateCount) + } + + @Test + fun executeWithQuery() { + val sql = "SELECT COUNT(*) FROM users" + whenever(mockDatabase.query(sql, emptyArray())) doReturn mockCursor + assertTrue(statement.execute(sql)) + assertNotNull(statement.resultSet) + assertEquals(-1, statement.updateCount) + } + + @Test + fun executeWithUpdate() { + val sql = "UPDATE users SET name = 'updated'" + val mockStatement = mock { + whenever(it.executeUpdateDelete()) doReturn 3 + } + whenever(mockDatabase.compileStatement(sql, null)) doReturn mockStatement + val result = statement.execute(sql) + assertFalse(result) + assertNull(statement.resultSet) + assertEquals(3, statement.updateCount) + } + + @Test + fun executeBatch() { + statement.apply { + addBatch("INSERT INTO users (name) VALUES ('user1')") + addBatch("INSERT INTO users (name) VALUES ('user2')") + } + val mockStatement = mock { + whenever(it.executeUpdateDelete()) doReturn 1 + } + whenever(mockDatabase.compileStatement(any(), isNull())) doReturn mockStatement + statement.executeBatch().run { + assertEquals(2, size) + assertEquals(1, this[0]) + assertEquals(1, this[1]) + } + } + + @Test + fun clearBatch() { + assertEquals(0, statement.apply { + addBatch("INSERT INTO users (name) VALUES ('test')") + clearBatch() + }.executeBatch().size) + } + + @Test + fun closure(): Unit = statement.run { + assertFalse(isClosed) + close() + assertTrue(isClosed) + assertFailsWith { + executeQuery("SELECT 1") + } + } + + @Test + fun maxFieldSize(): Unit = statement.run { + assertEquals(0, maxFieldSize) + maxFieldSize = 1000 + assertEquals(1000, maxFieldSize) + assertFailsWith { + maxFieldSize = -1 + } + } + + @Test + fun maxRows(): Unit = statement.run { + assertEquals(0, maxRows) + maxRows = 100 + assertEquals(100, maxRows) + assertFailsWith { + maxRows = -1 + } + } + + @Test + fun escapeProcessing(): Unit = statement.run { + assertTrue(escapeProcessing) + setEscapeProcessing(false) + assertFalse(escapeProcessing) + } + + @Test + fun queryTimeout(): Unit = statement.run { + assertEquals(0, queryTimeout) + queryTimeout = 30 + assertEquals(30, queryTimeout) + assertFailsWith { + queryTimeout = -1 + } + } + + @Test + fun fetchDirection(): Unit = statement.run { + assertEquals(ResultSet.FETCH_FORWARD, fetchDirection) + fetchDirection = ResultSet.FETCH_FORWARD + assertEquals(ResultSet.FETCH_FORWARD, fetchDirection) + assertFailsWith { + fetchDirection = ResultSet.FETCH_REVERSE + } + } + + @Test + fun fetchSize(): Unit = statement.run { + assertEquals(0, fetchSize) + fetchSize = 100 + assertEquals(100, fetchSize) + assertFailsWith { + fetchSize = -1 + } + } + + @Test + fun resultSetConcurrency() { + assertEquals(ResultSet.CONCUR_READ_ONLY, statement.resultSetConcurrency) + } + + @Test + fun resultSetType() { + assertEquals(ResultSet.TYPE_FORWARD_ONLY, statement.resultSetType) + } + + @Test + fun resultSetHoldability() { + assertEquals(ResultSet.CLOSE_CURSORS_AT_COMMIT, statement.resultSetHoldability) + } + + @Test + fun getConnection() { + assertEquals(mockConnection, statement.connection) + } + + @Test + fun getWarnings() { + assertNull(statement.warnings) + } + + @Test + fun clearWarnings() { + statement.clearWarnings() + assertNull(statement.warnings) + } + + @Test + fun cancellation() { + statement.cancel() + } + + @Test + fun getMoreResults(): Unit = statement.run { + assertFalse(getMoreResults()) + assertFalse(getMoreResults(Statement.CLOSE_CURRENT_RESULT)) + } + + @Test + fun getGeneratedKeys() { + assertFailsWith { + statement.generatedKeys + } + } + + @Test + fun executeUpdateWithGeneratedKeys() { + assertFailsWith { + statement.executeUpdate("INSERT INTO users (name) VALUES ('test')", Statement.RETURN_GENERATED_KEYS) + } + } + + @Test + fun executeWithGeneratedKeys() { + assertFailsWith { + statement.execute("INSERT INTO users (name) VALUES ('test')", Statement.RETURN_GENERATED_KEYS) + } + } + + @Test + fun wrapperInterface(): Unit = statement.run { + assertTrue(isWrapperFor(JdbcStatement::class.java)) + assertFalse(isWrapperFor(String::class.java)) + val unwrapped = unwrap(JdbcStatement::class.java) + assertSame(this, unwrapped) + assertFailsWith { + unwrap(String::class.java) + } + } + + @Test + fun poolable(): Unit = statement.run { + assertFalse(isPoolable) + isPoolable = true + assertTrue(isPoolable) + } + + @Test + fun closeOnCompletion(): Unit = statement.run { + assertFalse(isCloseOnCompletion) + closeOnCompletion() + assertTrue(isCloseOnCompletion) + } + + @Test + fun invalidOperationsAfterClose(): Unit = statement.run { + close() + assertFailsWith { + executeUpdate("INSERT INTO users (name) VALUES ('test')") + } + assertFailsWith { + execute("SELECT 1") + } + assertFailsWith { + addBatch("INSERT INTO users (name) VALUES ('test')") + } + assertFailsWith { + executeBatch() + } + } + + @Test + fun queryTypeDetection() { + val queries = listOf( + "SELECT * FROM users" to true, + " select count(*) from users " to true, + "WITH cte AS (SELECT * FROM users) SELECT * FROM cte" to true, + "INSERT INTO users (name) VALUES ('test')" to false, + "UPDATE users SET name = 'updated'" to false, + "DELETE FROM users WHERE id = 1" to false, + "CREATE TABLE test (id INTEGER)" to false, + "DROP TABLE test" to false + ) + whenever(mockDatabase.query(any(), any>())) doReturn mockCursor + val mockSqlStatement = mock { + whenever(it.executeUpdateDelete()) doReturn 1 + } + whenever(mockDatabase.compileStatement(any(), isNull())) doReturn mockSqlStatement + queries.forEach { (sql, isQuery) -> + val result = statement.execute(sql) + assertEquals(isQuery, result, "SQL: $sql should ${if (isQuery) { + "produce a result set" + } else { + "be an update" + }}") + } + } + + @Test + fun pragmaQuery() { + whenever(mockDatabase.query(any(), any>())) doReturn mockCursor + assertTrue(statement.execute("PRAGMA table_info(users)")) + assertNotNull(statement.resultSet) + } + + @Test + fun explainQuery() { + whenever(mockDatabase.query(any(), any>())) doReturn mockCursor + assertTrue(statement.execute("EXPLAIN SELECT * FROM users")) + assertNotNull(statement.resultSet) + } + + @Test + fun executeUpdateOverloads() { + val sqlStatement = mock { + whenever(it.executeUpdateDelete()) doReturn 1 + } + whenever(mockDatabase.compileStatement(any(), isNull())) doReturn sqlStatement + statement.run { + assertEquals(1, executeUpdate("INSERT INTO users (name) VALUES ('test')", Statement.NO_GENERATED_KEYS)) + assertEquals(1, executeUpdate("INSERT INTO users (name) VALUES ('test')", intArrayOf(1))) + assertEquals(1, executeUpdate("INSERT INTO users (name) VALUES ('test')", arrayOf("id"))) + } + } + + @Test + fun executeOverloads() { + whenever(mockDatabase.query(any(), any>())) doReturn mockCursor + statement.run { + assertTrue(execute("SELECT * FROM users", Statement.NO_GENERATED_KEYS)) + assertTrue(execute("SELECT * FROM users", intArrayOf(1))) + assertTrue(execute("SELECT * FROM users", arrayOf("id"))) + } + } + + @Test + fun setCursorName() { + assertFailsWith { + statement.setCursorName("test") + } + } + + @Test + fun unwrapToSQLDatabase() { + assertSame(mockDatabase, statement.unwrap(SQLDatabase::class.java)) + } + + @Test + fun isWrapperForSQLDatabase() { + assertTrue(statement.isWrapperFor(SQLDatabase::class.java)) + } + + @Test + fun currentResultSetClearedOnClose() { + whenever(mockDatabase.query(any(), any>())) doReturn mockCursor + statement.run { + executeQuery("SELECT * FROM users") + assertNotNull(resultSet) + close() + assertNull(resultSet) + } + } + + @Test + fun getResultSetAfterUpdate() { + val sqlStatement = mock { + whenever(it.executeUpdateDelete()) doReturn 1 + } + whenever(mockDatabase.compileStatement(any(), isNull())) doReturn sqlStatement + statement.run { + executeUpdate("INSERT INTO users (name) VALUES ('test')") + assertNull(resultSet) + } + } + + @Test + fun getUpdateCountAfterQuery() { + whenever(mockDatabase.query(any(), any>())) doReturn mockCursor + statement.run { + executeQuery("SELECT * FROM users") + assertEquals(-1, updateCount) + } + } + + @Test + fun closeIdempotent(): Unit = statement.run { + assertFalse(isClosed) + close() + assertTrue(isClosed) + close() + assertTrue(isClosed) + } + + @Test + fun executeQueryException() { + whenever(mockDatabase.query(any(), any>())) doThrow RuntimeException("Database error") + assertFailsWith { + statement.executeQuery("SELECT * FROM users") + } + } + + @Test + fun executeUpdateException() { + whenever(mockDatabase.compileStatement(any(), isNull())) doThrow RuntimeException("Database error") + assertFailsWith { + statement.executeUpdate("INSERT INTO users (name) VALUES ('test')") + } + } + + @Test + fun executeBatchWithException() { + whenever(mockDatabase.compileStatement(any(), isNull())) doThrow RuntimeException("Database error") + statement.run { + addBatch("INSERT INTO users (name) VALUES ('test')") + assertFailsWith { + executeBatch() + } + } + } + + @Test + fun addBatchWithBlankSql(): Unit = statement.run { + assertFailsWith { + addBatch("") + } + assertFailsWith { + addBatch(" ") + } + } + + @Test + fun executeBatchWithSelectStatement(): Unit = statement.run { + addBatch("SELECT * FROM users") + assertFailsWith { + executeBatch() + } + } + + @Test + fun executeBatchClearsOnFailure() { + val sqlStatement = mock { + whenever(it.executeUpdateDelete()) doReturn 1 + } + mockDatabase.apply { + whenever(compileStatement("INSERT INTO users (name) VALUES ('test')", null)) doReturn sqlStatement + whenever(compileStatement("INVALID SQL", null)) doThrow SQLException("Syntax error") + } + statement.run { + addBatch("INSERT INTO users (name) VALUES ('test')") + addBatch("INVALID SQL") + assertFailsWith { + executeBatch() + } + assertEquals(0, executeBatch().size) + } + } + + @Test + fun setPoolableWhenClosed(): Unit = statement.run { + close() + assertFailsWith { + isPoolable = true + } + } + + @Test + fun setMaxFieldSizeWhenClosed(): Unit = statement.run { + close() + assertFailsWith { + maxFieldSize = 100 + } + } + + @Test + fun addBatchWhenClosed(): Unit = statement.run { + close() + assertFailsWith { + addBatch("INSERT INTO users (name) VALUES ('test')") + } + } + + @Test + fun clearBatchWhenClosed(): Unit = statement.run { + close() + assertFailsWith { + clearBatch() + } + } + + @Test + fun closeOnCompletionWhenClosed(): Unit = statement.run { + close() + assertFailsWith { + closeOnCompletion() + } + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/util/ConnectionURLTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/util/ConnectionURLTest.kt new file mode 100644 index 0000000000..377fe07ae5 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/util/ConnectionURLTest.kt @@ -0,0 +1,105 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.util + +import org.junit.jupiter.api.Test +import java.sql.SQLException +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +internal class ConnectionURLTest { + @Test + fun basicURL(): Unit = ConnectionURL.parse("jdbc:selekt:/path/to/test.db").run { + assertEquals("/path/to/test.db", databasePath) + assertTrue(properties.isEmpty) + } + + @Test + fun urlWithProperties(): Unit = ConnectionURL.parse( + "jdbc:selekt:/path/to/test.db?encrypt=true&key=abc123&poolSize=5" + ).run { + assertEquals("/path/to/test.db", databasePath) + assertEquals("true", getProperty("encrypt")) + assertEquals("abc123", getProperty("key")) + assertEquals("5", getProperty("poolSize")) + } + + @Test + fun booleanProperties(): Unit = ConnectionURL.parse("jdbc:selekt:/test.db?encrypt=true&foreignKeys=false").run { + assertTrue(getBooleanProperty("encrypt")) + assertFalse(getBooleanProperty("foreignKeys")) + assertFalse(getBooleanProperty("nonexistent")) + } + + @Test + fun intProperties(): Unit = ConnectionURL.parse("jdbc:selekt:/test.db?poolSize=10&busyTimeout=5000").run { + assertEquals(10, getIntProperty("poolSize")) + assertEquals(5_000, getIntProperty("busyTimeout")) + assertEquals(0, getIntProperty("nonexistent")) + assertEquals(42, getIntProperty("nonexistent", 42)) + } + + @Test + fun invalidURL() { + assertFailsWith { + ConnectionURL.parse("invalid:url") + } + assertFailsWith { + ConnectionURL.parse("jdbc:other:/test.db") + } + } + + @Test + fun emptyPath() { + assertFailsWith { + ConnectionURL.parse("jdbc:selekt:") + } + } + + @Test + fun urlValidation() { + listOf( + "jdbc:selekt:/test.db", + "jdbc:selekt:/path/to/test.db?prop=value" + ).forEach { + assertTrue(ConnectionURL.isValidUrl(it)) + } + listOf( + "jdbc:other:/test.db", + "invalid", + null + ).forEach { + assertFalse(ConnectionURL.isValidUrl(it)) + } + } + + @Test + fun connectionToString(): Unit = ConnectionURL.parse( + "jdbc:selekt:/test.db?encrypt=true&poolSize=10" + ).toString().run { + assertTrue(startsWith("jdbc:selekt:/test.db")) + assertTrue(contains("encrypt=true")) + assertTrue(contains("poolSize=10")) + } + + @Test + fun urlEncoding() { + assertEquals("hello world", ConnectionURL.parse("jdbc:selekt:/test.db?key=hello%20world").getProperty("key")) + } +} diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/util/TypeMappingTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/util/TypeMappingTest.kt new file mode 100644 index 0000000000..2dacefcae2 --- /dev/null +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/util/TypeMappingTest.kt @@ -0,0 +1,473 @@ +/* + * Copyright 2026 Bloomberg Finance L.P. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.bloomberg.selekt.jdbc.util + +import com.bloomberg.selekt.ColumnType +import org.junit.jupiter.api.Test +import java.math.BigDecimal +import java.sql.Date +import java.sql.Time +import java.sql.Timestamp +import java.sql.Types +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.test.assertTrue + +internal class TypeMappingTest { + @Test + fun selektToJdbcTypeMapping() { + assertEquals(Types.BIGINT, TypeMapping.toJdbcType(ColumnType.INTEGER)) + assertEquals(Types.DOUBLE, TypeMapping.toJdbcType(ColumnType.FLOAT)) + assertEquals(Types.VARCHAR, TypeMapping.toJdbcType(ColumnType.STRING)) + assertEquals(Types.VARBINARY, TypeMapping.toJdbcType(ColumnType.BLOB)) + assertEquals(Types.NULL, TypeMapping.toJdbcType(ColumnType.NULL)) + } + + @Test + fun jdbcToSelektTypeMapping() { + assertEquals(ColumnType.INTEGER, TypeMapping.toSelektType(Types.INTEGER)) + assertEquals(ColumnType.INTEGER, TypeMapping.toSelektType(Types.BIGINT)) + assertEquals(ColumnType.INTEGER, TypeMapping.toSelektType(Types.BOOLEAN)) + + assertEquals(ColumnType.FLOAT, TypeMapping.toSelektType(Types.DOUBLE)) + assertEquals(ColumnType.FLOAT, TypeMapping.toSelektType(Types.FLOAT)) + assertEquals(ColumnType.FLOAT, TypeMapping.toSelektType(Types.DECIMAL)) + + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.VARCHAR)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.DATE)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.TIMESTAMP)) + + assertEquals(ColumnType.BLOB, TypeMapping.toSelektType(Types.BINARY)) + assertEquals(ColumnType.BLOB, TypeMapping.toSelektType(Types.VARBINARY)) + + assertEquals(ColumnType.NULL, TypeMapping.toSelektType(Types.NULL)) + } + + @Test + fun convertFromSQLiteToBoolean() { + assertEquals(true, TypeMapping.convertFromSQLite(1L, Types.BOOLEAN)) + assertEquals(false, TypeMapping.convertFromSQLite(0L, Types.BOOLEAN)) + assertEquals(true, TypeMapping.convertFromSQLite("true", Types.BOOLEAN)) + assertEquals(true, TypeMapping.convertFromSQLite("1", Types.BOOLEAN)) + assertEquals(false, TypeMapping.convertFromSQLite("false", Types.BOOLEAN)) + assertEquals(false, TypeMapping.convertFromSQLite("other", Types.BOOLEAN)) + } + + @Test + fun convertFromSQLiteToInteger() { + assertEquals(42, TypeMapping.convertFromSQLite(42L, Types.INTEGER)) + assertEquals(42, TypeMapping.convertFromSQLite("42", Types.INTEGER)) + assertEquals(0, TypeMapping.convertFromSQLite("invalid", Types.INTEGER)) + assertEquals(42, TypeMapping.convertFromSQLite(42.7, Types.INTEGER)) + } + + @Test + fun convertFromSQLiteToString() { + assertEquals("hello", TypeMapping.convertFromSQLite("hello", Types.VARCHAR)) + assertEquals("42", TypeMapping.convertFromSQLite(42L, Types.VARCHAR)) + assertEquals("3.14", TypeMapping.convertFromSQLite(3.14, Types.VARCHAR)) + } + + @Test + fun convertFromSQLiteToDecimal() { + assertEquals(3.14159, (TypeMapping.convertFromSQLite(3.14159, Types.DECIMAL) as BigDecimal).toDouble(), 0.000001) + assertEquals(123.45, (TypeMapping.convertFromSQLite("123.45", Types.DECIMAL) as BigDecimal).toDouble(), 0.000001) + assertEquals(BigDecimal.ZERO, TypeMapping.convertFromSQLite("invalid", Types.DECIMAL) as BigDecimal) + } + + @Test + fun convertFromSQLiteToDate() { + assertTrue(TypeMapping.convertFromSQLite("2025-12-25", Types.DATE) is Date) + assertNull(TypeMapping.convertFromSQLite("invalid-date", Types.DATE)) + } + + @Test + fun convertFromSQLiteToTime() { + assertTrue(TypeMapping.convertFromSQLite("10:30:45", Types.TIME) is Time) + assertNull(TypeMapping.convertFromSQLite("invalid-time", Types.TIME)) + } + + @Test + fun convertFromSQLiteToTimestamp() { + listOf( + "2025-12-25T10:30:45", + "2025-12-25 10:30:45", + "1640995200000" + ).forEach { + assertTrue(TypeMapping.convertFromSQLite(it, Types.TIMESTAMP) is Timestamp) + } + assertNull(TypeMapping.convertFromSQLite("invalid-timestamp", Types.TIMESTAMP)) + } + + @Test + fun convertFromSQLiteToByteArray() { + assertEquals(4, (TypeMapping.convertFromSQLite(byteArrayOf(1, 2, 3, 4), Types.VARBINARY) as ByteArray).size) + assertTrue( + (TypeMapping.convertFromSQLite("hello", Types.VARBINARY) as ByteArray) + .contentEquals("hello".toByteArray(Charsets.UTF_8)) + ) + } + + @Test + fun convertToSQLite() { + assertEquals(1L, TypeMapping.convertToSQLite(true)) + assertEquals(0L, TypeMapping.convertToSQLite(false)) + assertEquals(42L, TypeMapping.convertToSQLite(42)) + assertEquals(42L, TypeMapping.convertToSQLite(42L)) + assertEquals(3.14, TypeMapping.convertToSQLite(3.14)) + assertEquals("hello", TypeMapping.convertToSQLite("hello")) + + val bytes = byteArrayOf(1, 2, 3) + assertEquals(bytes, TypeMapping.convertToSQLite(bytes)) + + assertNull(TypeMapping.convertToSQLite(null)) + } + + @Test + fun getJavaClassName() { + assertEquals(Boolean::class.java.name, TypeMapping.getJavaClassName(Types.BOOLEAN)) + assertEquals(Int::class.java.name, TypeMapping.getJavaClassName(Types.INTEGER)) + assertEquals(Long::class.java.name, TypeMapping.getJavaClassName(Types.BIGINT)) + assertEquals(Double::class.java.name, TypeMapping.getJavaClassName(Types.DOUBLE)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.VARCHAR)) + assertEquals(ByteArray::class.java.name, TypeMapping.getJavaClassName(Types.VARBINARY)) + } + + @Test + fun getJdbcTypeName() { + assertEquals("BOOLEAN", TypeMapping.getJdbcTypeName(Types.BOOLEAN)) + assertEquals("INTEGER", TypeMapping.getJdbcTypeName(Types.INTEGER)) + assertEquals("VARCHAR", TypeMapping.getJdbcTypeName(Types.VARCHAR)) + assertEquals("VARBINARY", TypeMapping.getJdbcTypeName(Types.VARBINARY)) + assertEquals("NULL", TypeMapping.getJdbcTypeName(Types.NULL)) + assertEquals("OTHER", TypeMapping.getJdbcTypeName(999_999)) + } + + @Test + fun getPrecisionAndScale() { + assertEquals(19, TypeMapping.getPrecision(Types.BIGINT)) + assertEquals(0, TypeMapping.getScale(Types.BIGINT)) + + assertEquals(15, TypeMapping.getPrecision(Types.DOUBLE)) + assertEquals(15, TypeMapping.getScale(Types.DOUBLE)) + + assertEquals(0, TypeMapping.getPrecision(Types.VARCHAR)) + assertEquals(0, TypeMapping.getScale(Types.VARCHAR)) + } + + @Test + fun convertFromSQLiteWithNullValue() { + assertNull(TypeMapping.convertFromSQLite(null, Types.INTEGER)) + assertNull(TypeMapping.convertFromSQLite(null, Types.VARCHAR)) + assertNull(TypeMapping.convertFromSQLite(null, Types.BOOLEAN)) + } + + @Test + fun convertFromSQLiteToTinyInt() { + assertEquals(42.toByte(), TypeMapping.convertFromSQLite(42L, Types.TINYINT)) + assertEquals(42.toByte(), TypeMapping.convertFromSQLite("42", Types.TINYINT)) + assertEquals(0.toByte(), TypeMapping.convertFromSQLite("invalid", Types.TINYINT)) + assertEquals(0.toByte(), TypeMapping.convertFromSQLite(Any(), Types.TINYINT)) + } + + @Test + fun convertFromSQLiteToSmallInt() { + assertEquals(42.toShort(), TypeMapping.convertFromSQLite(42L, Types.SMALLINT)) + assertEquals(42.toShort(), TypeMapping.convertFromSQLite("42", Types.SMALLINT)) + assertEquals(0.toShort(), TypeMapping.convertFromSQLite("invalid", Types.SMALLINT)) + assertEquals(0.toShort(), TypeMapping.convertFromSQLite(Any(), Types.SMALLINT)) + } + + @Test + fun convertFromSQLiteToBigInt() { + assertEquals(42L, TypeMapping.convertFromSQLite(42, Types.BIGINT)) + assertEquals(42L, TypeMapping.convertFromSQLite("42", Types.BIGINT)) + assertEquals(0L, TypeMapping.convertFromSQLite("invalid", Types.BIGINT)) + assertEquals(0L, TypeMapping.convertFromSQLite(Any(), Types.BIGINT)) + } + + @Test + fun convertFromSQLiteToFloat() { + assertEquals(3.14f, TypeMapping.convertFromSQLite(3.14, Types.FLOAT)) + assertEquals(42f, TypeMapping.convertFromSQLite(42L, Types.FLOAT)) + assertEquals(3.14f, TypeMapping.convertFromSQLite("3.14", Types.FLOAT)) + assertEquals(0f, TypeMapping.convertFromSQLite("invalid", Types.FLOAT)) + assertEquals(0f, TypeMapping.convertFromSQLite(Any(), Types.FLOAT)) + } + + @Test + fun convertFromSQLiteToDouble() { + assertEquals(3.14, TypeMapping.convertFromSQLite(3.14, Types.DOUBLE)) + assertEquals(42.0, TypeMapping.convertFromSQLite(42L, Types.DOUBLE)) + assertEquals(3.14, TypeMapping.convertFromSQLite("3.14", Types.DOUBLE)) + assertEquals(0.0, TypeMapping.convertFromSQLite("invalid", Types.DOUBLE)) + assertEquals(0.0, TypeMapping.convertFromSQLite(Any(), Types.DOUBLE)) + } + + @Test + fun convertFromSQLiteToNumeric() { + val resultOne = TypeMapping.convertFromSQLite(3.14159, Types.NUMERIC) as BigDecimal + assertEquals(3.14159, resultOne.toDouble(), 0.000001) + val resultTwo = TypeMapping.convertFromSQLite("123.45", Types.NUMERIC) as BigDecimal + assertEquals(123.45, resultTwo.toDouble(), 0.000001) + assertEquals(BigDecimal.ZERO, TypeMapping.convertFromSQLite("invalid", Types.NUMERIC)) + assertEquals(BigDecimal.ZERO, TypeMapping.convertFromSQLite(Any(), Types.NUMERIC)) + } + + @Test + fun convertFromSQLiteDateFromNumber() { + val epochMillis = 1_640_995_200_000L + val date = TypeMapping.convertFromSQLite(epochMillis, Types.DATE) as Date + assertEquals(epochMillis, date.time) + } + + @Test + fun convertFromSQLiteTimeFromNumber() { + val epochMillis = 37_845_000L + val time = TypeMapping.convertFromSQLite(epochMillis, Types.TIME) as Time + assertEquals(epochMillis, time.time) + } + + @Test + fun convertFromSQLiteTimestampFromNumber() { + val epochMillis = 1_640_995_200_000L + val timestamp = TypeMapping.convertFromSQLite(epochMillis, Types.TIMESTAMP) as Timestamp + assertEquals(epochMillis, timestamp.time) + } + + @Test + fun convertFromSQLiteDateWithTimestampFallback() { + val date = TypeMapping.convertFromSQLite("2025-12-25 10:30:45", Types.DATE) + assertTrue(date is Date) + } + + @Test + fun convertFromSQLiteUnknownType() { + val value = "test" + assertEquals(value, TypeMapping.convertFromSQLite(value, 999_999)) + } + + @Test + fun convertFromSQLiteNonStringTypes() { + assertEquals("test", TypeMapping.convertFromSQLite("test", Types.CHAR)) + assertEquals("test", TypeMapping.convertFromSQLite("test", Types.LONGVARCHAR)) + assertEquals("test", TypeMapping.convertFromSQLite("test", Types.NCHAR)) + assertEquals("test", TypeMapping.convertFromSQLite("test", Types.NVARCHAR)) + assertEquals("test", TypeMapping.convertFromSQLite("test", Types.LONGNVARCHAR)) + assertEquals("test", TypeMapping.convertFromSQLite("test", Types.CLOB)) + assertEquals("test", TypeMapping.convertFromSQLite("test", Types.NCLOB)) + } + + @Test + fun convertBinaryTypes() { + val bytes = byteArrayOf(1, 2, 3) + assertTrue((TypeMapping.convertFromSQLite(bytes, Types.BINARY) as ByteArray).contentEquals(bytes)) + assertTrue((TypeMapping.convertFromSQLite(bytes, Types.LONGVARBINARY) as ByteArray).contentEquals(bytes)) + assertEquals(0, (TypeMapping.convertFromSQLite(Any(), Types.BINARY) as ByteArray).size) + } + + @Test + fun convertToBooleanFromAlreadyBoolean() { + assertEquals(true, TypeMapping.convertFromSQLite(true, Types.BOOLEAN)) + assertEquals(false, TypeMapping.convertFromSQLite(false, Types.BOOLEAN)) + } + + @Test + fun convertToSQLiteFromByte() { + assertEquals(42L, TypeMapping.convertToSQLite(42.toByte())) + } + + @Test + fun convertToSQLiteFromShort() { + assertEquals(42L, TypeMapping.convertToSQLite(42.toShort())) + } + + @Test + fun convertToSQLiteFromFloat() { + assertEquals(3.14f.toDouble(), TypeMapping.convertToSQLite(3.14f)) + } + + @Test + fun convertToSQLiteFromBigDecimal() { + assertEquals(123.45, TypeMapping.convertToSQLite(BigDecimal("123.45"))) + } + + @Test + fun convertToSQLiteFromDate() { + val date = Date.valueOf("2025-12-25") + assertEquals("2025-12-25", TypeMapping.convertToSQLite(date)) + } + + @Test + fun convertToSQLiteFromTime() { + val time = Time.valueOf("10:30:45") + assertEquals("10:30:45", TypeMapping.convertToSQLite(time)) + } + + @Test + fun convertToSQLiteFromTimestamp() { + val timestamp = Timestamp.valueOf("2025-12-25 10:30:45") + assertEquals("2025-12-25 10:30:45.0", TypeMapping.convertToSQLite(timestamp)) + } + + @Test + fun convertToSQLiteFromLocalDate() { + val localDate = java.time.LocalDate.of(2025, 12, 25) + assertEquals("2025-12-25", TypeMapping.convertToSQLite(localDate)) + } + + @Test + fun convertToSQLiteFromLocalTime() { + val localTime = java.time.LocalTime.of(10, 30, 45) + assertEquals("10:30:45", TypeMapping.convertToSQLite(localTime)) + } + + @Test + fun convertToSQLiteFromLocalDateTime() { + val localDateTime = java.time.LocalDateTime.of(2025, 12, 25, 10, 30, 45) + assertEquals("2025-12-25T10:30:45", TypeMapping.convertToSQLite(localDateTime)) + } + + @Test + fun convertToSQLiteFromUnknownObject() { + val obj = object : Any() { + override fun toString() = "custom" + } + assertEquals("custom", TypeMapping.convertToSQLite(obj)) + } + + @Test + fun jdbcToSelektTypeAdditionalMappings() { + assertEquals(ColumnType.INTEGER, TypeMapping.toSelektType(Types.TINYINT)) + assertEquals(ColumnType.INTEGER, TypeMapping.toSelektType(Types.SMALLINT)) + assertEquals(ColumnType.FLOAT, TypeMapping.toSelektType(Types.REAL)) + assertEquals(ColumnType.FLOAT, TypeMapping.toSelektType(Types.NUMERIC)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.CHAR)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.LONGVARCHAR)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.NCHAR)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.NVARCHAR)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.LONGNVARCHAR)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.CLOB)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.NCLOB)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.TIME)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(Types.TIMESTAMP_WITH_TIMEZONE)) + assertEquals(ColumnType.BLOB, TypeMapping.toSelektType(Types.LONGVARBINARY)) + assertEquals(ColumnType.BLOB, TypeMapping.toSelektType(Types.BLOB)) + assertEquals(ColumnType.STRING, TypeMapping.toSelektType(999_999)) // Unknown type defaults to STRING + } + + @Test + fun getJavaClassNameAdditionalTypes() { + assertEquals(Byte::class.java.name, TypeMapping.getJavaClassName(Types.TINYINT)) + assertEquals(Short::class.java.name, TypeMapping.getJavaClassName(Types.SMALLINT)) + assertEquals(Float::class.java.name, TypeMapping.getJavaClassName(Types.FLOAT)) + assertEquals(Float::class.java.name, TypeMapping.getJavaClassName(Types.REAL)) + assertEquals(BigDecimal::class.java.name, TypeMapping.getJavaClassName(Types.NUMERIC)) + assertEquals(BigDecimal::class.java.name, TypeMapping.getJavaClassName(Types.DECIMAL)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.CHAR)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.LONGVARCHAR)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.NCHAR)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.NVARCHAR)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.LONGNVARCHAR)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.CLOB)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(Types.NCLOB)) + assertEquals(Date::class.java.name, TypeMapping.getJavaClassName(Types.DATE)) + assertEquals(Time::class.java.name, TypeMapping.getJavaClassName(Types.TIME)) + assertEquals(Timestamp::class.java.name, TypeMapping.getJavaClassName(Types.TIMESTAMP)) + assertEquals(Timestamp::class.java.name, TypeMapping.getJavaClassName(Types.TIMESTAMP_WITH_TIMEZONE)) + assertEquals(ByteArray::class.java.name, TypeMapping.getJavaClassName(Types.BINARY)) + assertEquals(ByteArray::class.java.name, TypeMapping.getJavaClassName(Types.LONGVARBINARY)) + assertEquals(java.sql.Blob::class.java.name, TypeMapping.getJavaClassName(Types.BLOB)) + assertEquals(String::class.java.name, TypeMapping.getJavaClassName(999_999)) + } + + @Test + fun getJdbcTypeNameAllTypes() { + assertEquals("TINYINT", TypeMapping.getJdbcTypeName(Types.TINYINT)) + assertEquals("SMALLINT", TypeMapping.getJdbcTypeName(Types.SMALLINT)) + assertEquals("BIGINT", TypeMapping.getJdbcTypeName(Types.BIGINT)) + assertEquals("REAL", TypeMapping.getJdbcTypeName(Types.REAL)) + assertEquals("FLOAT", TypeMapping.getJdbcTypeName(Types.FLOAT)) + assertEquals("DOUBLE", TypeMapping.getJdbcTypeName(Types.DOUBLE)) + assertEquals("NUMERIC", TypeMapping.getJdbcTypeName(Types.NUMERIC)) + assertEquals("DECIMAL", TypeMapping.getJdbcTypeName(Types.DECIMAL)) + assertEquals("CHAR", TypeMapping.getJdbcTypeName(Types.CHAR)) + assertEquals("LONGVARCHAR", TypeMapping.getJdbcTypeName(Types.LONGVARCHAR)) + assertEquals("NCHAR", TypeMapping.getJdbcTypeName(Types.NCHAR)) + assertEquals("NVARCHAR", TypeMapping.getJdbcTypeName(Types.NVARCHAR)) + assertEquals("LONGNVARCHAR", TypeMapping.getJdbcTypeName(Types.LONGNVARCHAR)) + assertEquals("DATE", TypeMapping.getJdbcTypeName(Types.DATE)) + assertEquals("TIME", TypeMapping.getJdbcTypeName(Types.TIME)) + assertEquals("TIMESTAMP", TypeMapping.getJdbcTypeName(Types.TIMESTAMP)) + assertEquals("TIMESTAMP_WITH_TIMEZONE", TypeMapping.getJdbcTypeName(Types.TIMESTAMP_WITH_TIMEZONE)) + assertEquals("BINARY", TypeMapping.getJdbcTypeName(Types.BINARY)) + assertEquals("LONGVARBINARY", TypeMapping.getJdbcTypeName(Types.LONGVARBINARY)) + assertEquals("BLOB", TypeMapping.getJdbcTypeName(Types.BLOB)) + assertEquals("CLOB", TypeMapping.getJdbcTypeName(Types.CLOB)) + assertEquals("NCLOB", TypeMapping.getJdbcTypeName(Types.NCLOB)) + } + + @Test + fun getPrecisionAllTypes() { + assertEquals(1, TypeMapping.getPrecision(Types.BOOLEAN)) + assertEquals(3, TypeMapping.getPrecision(Types.TINYINT)) + assertEquals(5, TypeMapping.getPrecision(Types.SMALLINT)) + assertEquals(10, TypeMapping.getPrecision(Types.INTEGER)) + assertEquals(7, TypeMapping.getPrecision(Types.FLOAT)) + assertEquals(7, TypeMapping.getPrecision(Types.REAL)) + assertEquals(10, TypeMapping.getPrecision(Types.DATE)) + assertEquals(8, TypeMapping.getPrecision(Types.TIME)) + assertEquals(23, TypeMapping.getPrecision(Types.TIMESTAMP)) + assertEquals(0, TypeMapping.getPrecision(999_999)) + } + + @Test + fun getScaleAllTypes() { + assertEquals(7, TypeMapping.getScale(Types.FLOAT)) + assertEquals(7, TypeMapping.getScale(Types.REAL)) + assertEquals(15, TypeMapping.getScale(Types.DOUBLE)) + assertEquals(0, TypeMapping.getScale(Types.INTEGER)) + assertEquals(0, TypeMapping.getScale(Types.VARCHAR)) + assertEquals(0, TypeMapping.getScale(999_999)) + } + + @Test + fun convertTimestampWithTimezone() { + val timestampStr = "2025-12-25T10:30:45" + assertTrue(TypeMapping.convertFromSQLite(timestampStr, Types.TIMESTAMP_WITH_TIMEZONE) is Timestamp) + } + + @Test + fun convertFromSQLiteDateFromObject() { + assertNull(TypeMapping.convertFromSQLite(Any(), Types.DATE)) + } + + @Test + fun convertFromSQLiteTimeFromObject() { + assertNull(TypeMapping.convertFromSQLite(Any(), Types.TIME)) + } + + @Test + fun convertFromSQLiteTimestampFromObject() { + assertNull(TypeMapping.convertFromSQLite(Any(), Types.TIMESTAMP)) + } + + @Test + fun convertFromSQLiteBooleanFromObject() { + assertEquals(false, TypeMapping.convertFromSQLite(Any(), Types.BOOLEAN)) + } +} diff --git a/selekt-jvm/build.gradle.kts b/selekt-jvm/build.gradle.kts index e2e3389988..625079e761 100644 --- a/selekt-jvm/build.gradle.kts +++ b/selekt-jvm/build.gradle.kts @@ -28,7 +28,6 @@ plugins { `maven-publish` signing alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { diff --git a/selekt-sqlite3-api/build.gradle.kts b/selekt-sqlite3-api/build.gradle.kts index 821b2d499f..a31887b964 100644 --- a/selekt-sqlite3-api/build.gradle.kts +++ b/selekt-sqlite3-api/build.gradle.kts @@ -24,7 +24,6 @@ plugins { `maven-publish` signing alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { diff --git a/selekt-sqlite3-api/src/main/kotlin/com/bloomberg/selekt/IExternalSQLite.kt b/selekt-sqlite3-api/src/main/kotlin/com/bloomberg/selekt/IExternalSQLite.kt index abc4300941..c4e15a64a0 100644 --- a/selekt-sqlite3-api/src/main/kotlin/com/bloomberg/selekt/IExternalSQLite.kt +++ b/selekt-sqlite3-api/src/main/kotlin/com/bloomberg/selekt/IExternalSQLite.kt @@ -36,6 +36,8 @@ interface IExternalSQLite { fun bindParameterCount(statement: Long): Int + fun bindParameterIndex(statement: Long, name: String): Int + fun bindText(statement: Long, index: Int, value: String): SQLCode fun bindZeroBlob(statement: Long, index: Int, length: Int): SQLCode diff --git a/selekt-sqlite3-classes/build.gradle.kts b/selekt-sqlite3-classes/build.gradle.kts index 80bdb037f0..ed48cff4d8 100644 --- a/selekt-sqlite3-classes/build.gradle.kts +++ b/selekt-sqlite3-classes/build.gradle.kts @@ -30,7 +30,6 @@ plugins { alias(libs.plugins.jmh) alias(libs.plugins.kover) alias(libs.plugins.detekt) - alias(libs.plugins.ktlint) } repositories { diff --git a/selekt-sqlite3-classes/src/java17/kotlin/com/bloomberg/selekt/ExternalSQLite.kt b/selekt-sqlite3-classes/src/java17/kotlin/com/bloomberg/selekt/ExternalSQLite.kt index 23b6a5baef..18337b30c5 100644 --- a/selekt-sqlite3-classes/src/java17/kotlin/com/bloomberg/selekt/ExternalSQLite.kt +++ b/selekt-sqlite3-classes/src/java17/kotlin/com/bloomberg/selekt/ExternalSQLite.kt @@ -59,6 +59,8 @@ internal class ExternalSQLite( external override fun bindParameterCount(statement: Long): Int + external override fun bindParameterIndex(statement: Long, name: String): Int + external override fun bindText(statement: Long, index: Int, value: String): SQLCode external override fun bindZeroBlob(statement: Long, index: Int, length: Int): SQLCode diff --git a/selekt-sqlite3-classes/src/java25/kotlin/com/bloomberg/selekt/ExternalSQLite.kt b/selekt-sqlite3-classes/src/java25/kotlin/com/bloomberg/selekt/ExternalSQLite.kt index 8564213d02..4229fb817b 100644 --- a/selekt-sqlite3-classes/src/java25/kotlin/com/bloomberg/selekt/ExternalSQLite.kt +++ b/selekt-sqlite3-classes/src/java25/kotlin/com/bloomberg/selekt/ExternalSQLite.kt @@ -122,6 +122,16 @@ internal class ExternalSQLite( statement: Long ): Int = sqlite3_bind_parameter_count.invoke(MemorySegment.ofAddress(statement)) as Int + override fun bindParameterIndex( + statement: Long, + name: String + ): Int = Arena.ofConfined().use { + sqlite3_bind_parameter_index.invoke( + MemorySegment.ofAddress(statement), + it.allocateFrom(name) + ) as Int + } + override fun bindText( statement: Long, index: Int, @@ -616,12 +626,16 @@ internal class ExternalSQLite( private val sqliteTransient = MemorySegment.ofAddress(-1L) + private val criticalOption = Linker.Option.critical(true) + private val nonCriticalOption = Linker.Option.critical(false) + private val sqlite3_bind_blob: MethodHandle private val sqlite3_bind_double: MethodHandle private val sqlite3_bind_int: MethodHandle private val sqlite3_bind_int64: MethodHandle private val sqlite3_bind_null: MethodHandle private val sqlite3_bind_parameter_count: MethodHandle + private val sqlite3_bind_parameter_index: MethodHandle private val sqlite3_bind_text: MethodHandle private val sqlite3_bind_zeroblob: MethodHandle private val sqlite3_blob_bytes: MethodHandle @@ -687,267 +701,338 @@ internal class ExternalSQLite( loadLibrary(checkNotNull(ExternalSQLite::class.java.classLoader), "jni", "selekt") sqlite3_bind_blob = linker.downcallHandle( symbolLookup.find("sqlite3_bind_blob").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS), + criticalOption ) sqlite3_bind_double = linker.downcallHandle( symbolLookup.find("sqlite3_bind_double").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_DOUBLE) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_DOUBLE), + criticalOption ) sqlite3_bind_int = linker.downcallHandle( symbolLookup.find("sqlite3_bind_int").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT), + criticalOption ) sqlite3_bind_int64 = linker.downcallHandle( symbolLookup.find("sqlite3_bind_int64").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_LONG) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_LONG), + criticalOption ) sqlite3_bind_null = linker.downcallHandle( symbolLookup.find("sqlite3_bind_null").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_bind_parameter_count = linker.downcallHandle( symbolLookup.find("sqlite3_bind_parameter_count").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption + ) + sqlite3_bind_parameter_index = linker.downcallHandle( + symbolLookup.find("sqlite3_bind_parameter_index").orElseThrow(), + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS), + criticalOption ) sqlite3_bind_text = linker.downcallHandle( symbolLookup.find("sqlite3_bind_text").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS), + criticalOption ) sqlite3_bind_zeroblob = linker.downcallHandle( symbolLookup.find("sqlite3_bind_zeroblob").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT), + criticalOption ) sqlite3_blob_bytes = linker.downcallHandle( symbolLookup.find("sqlite3_blob_bytes").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_blob_close = linker.downcallHandle( symbolLookup.find("sqlite3_blob_close").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_blob_open = linker.downcallHandle( symbolLookup.find("sqlite3_blob_open").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, JAVA_LONG, JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, JAVA_LONG, JAVA_INT, ADDRESS), + nonCriticalOption ) sqlite3_blob_read = linker.downcallHandle( symbolLookup.find("sqlite3_blob_read").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, JAVA_INT), + nonCriticalOption ) sqlite3_blob_reopen = linker.downcallHandle( symbolLookup.find("sqlite3_blob_reopen").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG), + criticalOption ) sqlite3_blob_write = linker.downcallHandle( symbolLookup.find("sqlite3_blob_write").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, JAVA_INT), + nonCriticalOption ) sqlite3_busy_timeout = linker.downcallHandle( symbolLookup.find("sqlite3_busy_timeout").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_changes = linker.downcallHandle( symbolLookup.find("sqlite3_changes").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_clear_bindings = linker.downcallHandle( symbolLookup.find("sqlite3_clear_bindings").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_close_v2 = linker.downcallHandle( symbolLookup.find("sqlite3_close_v2").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + nonCriticalOption ) sqlite3_column_blob = linker.downcallHandle( symbolLookup.find("sqlite3_column_blob").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT) + FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_bytes = linker.downcallHandle( symbolLookup.find("sqlite3_column_bytes").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_count = linker.downcallHandle( symbolLookup.find("sqlite3_column_count").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_column_double = linker.downcallHandle( symbolLookup.find("sqlite3_column_double").orElseThrow(), - FunctionDescriptor.of(JAVA_DOUBLE, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_DOUBLE, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_int = linker.downcallHandle( symbolLookup.find("sqlite3_column_int").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_int64 = linker.downcallHandle( symbolLookup.find("sqlite3_column_int64").orElseThrow(), - FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_name = linker.downcallHandle( symbolLookup.find("sqlite3_column_name").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT) + FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_text = linker.downcallHandle( symbolLookup.find("sqlite3_column_text").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT) + FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_type = linker.downcallHandle( symbolLookup.find("sqlite3_column_type").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_column_value = linker.downcallHandle( symbolLookup.find("sqlite3_column_value").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT) + FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_db_handle = linker.downcallHandle( symbolLookup.find("sqlite3_db_handle").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS) + FunctionDescriptor.of(ADDRESS, ADDRESS), + criticalOption ) sqlite3_db_readonly = linker.downcallHandle( symbolLookup.find("sqlite3_db_readonly").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS), + criticalOption ) sqlite3_db_release_memory = linker.downcallHandle( symbolLookup.find("sqlite3_db_release_memory").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_db_status = linker.downcallHandle( symbolLookup.find("sqlite3_db_status").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_errcode = linker.downcallHandle( symbolLookup.find("sqlite3_errcode").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_errmsg = linker.downcallHandle( symbolLookup.find("sqlite3_errmsg").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS) + FunctionDescriptor.of(ADDRESS, ADDRESS), + criticalOption ) sqlite3_exec = linker.downcallHandle( symbolLookup.find("sqlite3_exec").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS), + nonCriticalOption ) sqlite3_expanded_sql = linker.downcallHandle( symbolLookup.find("sqlite3_expanded_sql").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS) + FunctionDescriptor.of(ADDRESS, ADDRESS), + criticalOption ) sqlite3_extended_errcode = linker.downcallHandle( symbolLookup.find("sqlite3_extended_errcode").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_extended_result_codes = linker.downcallHandle( symbolLookup.find("sqlite3_extended_result_codes").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_finalize = linker.downcallHandle( symbolLookup.find("sqlite3_finalize").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_get_autocommit = linker.downcallHandle( symbolLookup.find("sqlite3_get_autocommit").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_hard_heap_limit64 = linker.downcallHandle( symbolLookup.find("sqlite3_hard_heap_limit64").orElseThrow(), - FunctionDescriptor.of(JAVA_LONG, JAVA_LONG) + FunctionDescriptor.of(JAVA_LONG, JAVA_LONG), + criticalOption ) sqlite3_key = linker.downcallHandle( symbolLookup.find("sqlite3_key").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_keyword_count = linker.downcallHandle( symbolLookup.find("sqlite3_keyword_count").orElseThrow(), - FunctionDescriptor.of(JAVA_INT) + FunctionDescriptor.of(JAVA_INT), + criticalOption ) sqlite3_last_insert_rowid = linker.downcallHandle( symbolLookup.find("sqlite3_last_insert_rowid").orElseThrow(), - FunctionDescriptor.of(JAVA_LONG, ADDRESS) + FunctionDescriptor.of(JAVA_LONG, ADDRESS), + criticalOption ) sqlite3_libversion = linker.downcallHandle( symbolLookup.find("sqlite3_libversion").orElseThrow(), - FunctionDescriptor.of(ADDRESS) + FunctionDescriptor.of(ADDRESS), + criticalOption ) sqlite3_libversion_number = linker.downcallHandle( symbolLookup.find("sqlite3_libversion_number").orElseThrow(), - FunctionDescriptor.of(JAVA_INT) + FunctionDescriptor.of(JAVA_INT), + criticalOption ) sqlite3_memory_used = linker.downcallHandle( symbolLookup.find("sqlite3_memory_used").orElseThrow(), - FunctionDescriptor.of(JAVA_LONG) + FunctionDescriptor.of(JAVA_LONG), + criticalOption ) sqlite3_open_v2 = linker.downcallHandle( symbolLookup.find("sqlite3_open_v2").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS), + nonCriticalOption ) sqlite3_prepare_v2 = linker.downcallHandle( symbolLookup.find("sqlite3_prepare_v2").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS, ADDRESS), + criticalOption ) sqlite3_rekey = linker.downcallHandle( symbolLookup.find("sqlite3_rekey").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_release_memory = linker.downcallHandle( symbolLookup.find("sqlite3_release_memory").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, JAVA_INT), + criticalOption ) sqlite3_reset = linker.downcallHandle( symbolLookup.find("sqlite3_reset").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_sql = linker.downcallHandle( symbolLookup.find("sqlite3_sql").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS) + FunctionDescriptor.of(ADDRESS, ADDRESS), + criticalOption ) sqlite3_step = linker.downcallHandle( symbolLookup.find("sqlite3_step").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_stmt_busy = linker.downcallHandle( symbolLookup.find("sqlite3_stmt_busy").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_stmt_readonly = linker.downcallHandle( symbolLookup.find("sqlite3_stmt_readonly").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_stmt_status = linker.downcallHandle( symbolLookup.find("sqlite3_stmt_status").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT), + criticalOption ) sqlite3_threadsafe = linker.downcallHandle( symbolLookup.find("sqlite3_threadsafe").orElseThrow(), - FunctionDescriptor.of(JAVA_INT) + FunctionDescriptor.of(JAVA_INT), + criticalOption ) sqlite3_total_changes = linker.downcallHandle( symbolLookup.find("sqlite3_total_changes").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_trace_v2 = linker.downcallHandle( symbolLookup.find("sqlite3_trace_v2").orElseThrow(), - FunctionDescriptor.ofVoid(ADDRESS, JAVA_INT, ADDRESS, ADDRESS) + FunctionDescriptor.ofVoid(ADDRESS, JAVA_INT, ADDRESS, ADDRESS), + criticalOption ) sqlite3_txn_state = linker.downcallHandle( symbolLookup.find("sqlite3_txn_state").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS), + criticalOption ) sqlite3_value_dup = linker.downcallHandle( symbolLookup.find("sqlite3_value_dup").orElseThrow(), - FunctionDescriptor.of(ADDRESS, ADDRESS) + FunctionDescriptor.of(ADDRESS, ADDRESS), + criticalOption ) sqlite3_value_free = linker.downcallHandle( symbolLookup.find("sqlite3_value_free").orElseThrow(), - FunctionDescriptor.ofVoid(ADDRESS) + FunctionDescriptor.ofVoid(ADDRESS), + criticalOption ) sqlite3_value_frombind = linker.downcallHandle( symbolLookup.find("sqlite3_value_frombind").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS), + criticalOption ) sqlite3_wal_autocheckpoint = linker.downcallHandle( symbolLookup.find("sqlite3_wal_autocheckpoint").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT) + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT), + criticalOption ) sqlite3_wal_checkpoint_v2 = linker.downcallHandle( symbolLookup.find("sqlite3_wal_checkpoint_v2").orElseThrow(), - FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS, ADDRESS) + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS, ADDRESS), + criticalOption ) } } diff --git a/selekt-sqlite3-sqlcipher/build.gradle.kts b/selekt-sqlite3-sqlcipher/build.gradle.kts index 549239adc9..546c6c7815 100644 --- a/selekt-sqlite3-sqlcipher/build.gradle.kts +++ b/selekt-sqlite3-sqlcipher/build.gradle.kts @@ -44,6 +44,9 @@ fun platformIdentifier() = "${osName()}-${System.getProperty("os.arch")}" tasks.named("jar") { archiveClassifier.set(platformIdentifier()) + metaInf { + from("$rootDir/SQLCIPHER_LICENSE") + } } tasks.withType().configureEach { diff --git a/settings.gradle.kts b/settings.gradle.kts index 7e6b659fd3..ec8924a7c2 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -28,6 +28,7 @@ include(":selekt-api") include(":selekt-bom") include(":selekt-commons") include(":selekt-java") +include(":selekt-jdbc") include(":selekt-jvm") include(":selekt-sqlite3-api") include(":selekt-sqlite3-classes")