Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 42 additions & 26 deletions jack-core/src/com/rapleaf/jack/DatabaseConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,26 @@
// limitations under the License.
package com.rapleaf.jack;

import static com.rapleaf.jack.DatabaseConnectionConstants.MYSQL_JDBC_DRIVER;
import static com.rapleaf.jack.DatabaseConnectionConstants.POSTGRESQL_JDBC_DRIVER;
import static com.rapleaf.jack.DatabaseConnectionConstants.REDSHIFT_JDBC_DRIVER;

import java.sql.Connection;
import java.sql.DriverManager;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import com.google.common.base.Optional;

import static com.rapleaf.jack.DatabaseConnectionConstants.MYSQL_JDBC_DRIVER;
import static com.rapleaf.jack.DatabaseConnectionConstants.POSTGRESQL_JDBC_DRIVER;
import static com.rapleaf.jack.DatabaseConnectionConstants.REDSHIFT_JDBC_DRIVER;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* The DatabaseConnection class manages connections to your databases. The
* database to be used is specified in config/environment.yml. This file
* in turn points to a set of login credentials in config/database.yml.
* The DatabaseConnection class manages connections to your databases. The database to be used is specified in
* config/environment.yml. This file in turn points to a set of login credentials in config/database.yml.
* <p/>
* All public methods methods of DatabaseConnection throw RuntimeExceptions
* (rather than IO or SQL exceptions).
* All public methods methods of DatabaseConnection throw RuntimeExceptions (rather than IO or SQL exceptions).
*/
public class DatabaseConnection extends BaseDatabaseConnection {

private static final String PARTITION_NUM_ENV_VARIABLE_NAME = "TLB_PARTITION_NUMBER";
private static final int DEFAULT_CONNECTION_MAX_RETRIES = 7;
private static final int DEFAULT_VALIDATION_TIMEOUT_SECONDS = 3;
Expand Down Expand Up @@ -72,6 +71,15 @@ public class DatabaseConnection extends BaseDatabaseConnection {
private long expiration;

public DatabaseConnection(String dbname_key, long expiration, String driverClass) {
this(dbname_key, expiration, driverClass, Collections.emptyMap());
}

public DatabaseConnection(
String dbname_key,
long expiration,
String driverClass,
Map<String, String> additionalConnectionStringOptions
) {
DatabaseConnectionConfiguration config = DatabaseConnectionConfiguration.loadFromEnvironment(dbname_key);
// get server credentials from database info
String adapter = config.getAdapter();
Expand All @@ -80,7 +88,6 @@ public DatabaseConnection(String dbname_key, long expiration, String driverClass
}
this.driverClass = driverClass;


String driver = adapter;
if (ADAPTER_TO_DRIVER.containsKey(adapter)) {
driver = ADAPTER_TO_DRIVER.get(adapter);
Expand All @@ -92,34 +99,42 @@ public DatabaseConnection(String dbname_key, long expiration, String driverClass
connectionStringBuilder.append(":").append(config.getPort().get());
}
connectionStringBuilder.append("/").append(getDbName(config.getDatabaseName(), config.enableParallelTests()));
if (!additionalConnectionStringOptions.isEmpty()) {
connectionStringBuilder.append("?");
connectionStringBuilder.append(
additionalConnectionStringOptions.entrySet()
.stream()
.map(e -> String.join("=", e.getKey(), e.getValue()))
.collect(Collectors.joining("&"))
);
}
connectionString = connectionStringBuilder.toString();
username = config.getUsername();
password = config.getPassword();

connectionMaxRetries = config.getConnectionMaxRetries()
.or(DEFAULT_CONNECTION_MAX_RETRIES);
.orElse(DEFAULT_CONNECTION_MAX_RETRIES);
validationTimeoutSeconds = config.getConnectionValidationTimeout()
.or(DEFAULT_VALIDATION_TIMEOUT_SECONDS);
.orElse(DEFAULT_VALIDATION_TIMEOUT_SECONDS);

this.expiration = expiration;
updateExpiration();
}

/**
* Get a Connection to a database. If there is no connection, create a new one.
* If the connection hasn't been used in a long time, close it and create a new one.
* We do this because MySQL has an 8 hour idle connection timeout.
* Get a Connection to a database. If there is no connection, create a new one. If the connection hasn't been used in
* a long time, close it and create a new one. We do this because MySQL has an 8 hour idle connection timeout.
* <p>
* Because of the intermittent downtime of SQL service, it implements exponential retry policy.
* The default retry policy retries seven times, which handles SQL downtime less than approx. two minutes.
* Because of the intermittent downtime of SQL service, it implements exponential retry policy. The default retry
* policy retries seven times, which handles SQL downtime less than approx. two minutes.
*/

public Connection getConnectionInternal() {
for (int retryCount = 0; retryCount < connectionMaxRetries; ++retryCount) {
try {
if (conn == null) {
Class.forName(driverClass);
conn = DriverManager.getConnection(connectionString, username.orNull(), password.orNull());
conn = DriverManager.getConnection(connectionString, username.orElse(null), password.orElse(null));
} else if (isExpired() || !conn.isValid(validationTimeoutSeconds)) {
resetConnection();
}
Expand All @@ -128,7 +143,8 @@ public Connection getConnectionInternal() {
} catch (Exception e) { //IOEx., SQLEx.
// if it is the last retry, throw exception
if (retryCount == connectionMaxRetries - 1) {
throw new RuntimeException(String.format("Could not establish connection after %d attempts",
throw new RuntimeException(String.format(
"Could not establish connection after %d attempts",
connectionMaxRetries
), e);
}
Expand All @@ -140,19 +156,19 @@ public Connection getConnectionInternal() {
}
}
}
throw new RuntimeException(String.format("Could not establish connection after %d attempts",
throw new RuntimeException(String.format(
"Could not establish connection after %d attempts",
connectionMaxRetries
));
}


/**
* When using a parallel test environment, we append an integer that lives in
* an environment variable to the database name.
* When using a parallel test environment, we append an integer that lives in an environment variable to the database
* name.
*
* @param base_name the name of the database
* @param use_parallel if true, append an integer specified in an environment
* variable to the end of base_name
* @param use_parallel if true, append an integer specified in an environment variable to the end of base_name
* @return the name of the database that we should connect to
*/
protected String getDbName(String base_name, Boolean use_parallel) {
Expand Down
Loading