diff --git a/jack-core/src/com/rapleaf/jack/DatabaseConnection.java b/jack-core/src/com/rapleaf/jack/DatabaseConnection.java index 4f373e64..722f03f1 100644 --- a/jack-core/src/com/rapleaf/jack/DatabaseConnection.java +++ b/jack-core/src/com/rapleaf/jack/DatabaseConnection.java @@ -14,27 +14,26 @@ // limitations under the License. package com.rapleaf.jack; +import static com.rapleaf.jack.DatabaseConnectionConstants.MYSQL_JDBC_DRIVER; +import static com.rapleaf.jack.DatabaseConnectionConstants.POSTGRESQL_JDBC_DRIVER; +import static com.rapleaf.jack.DatabaseConnectionConstants.REDSHIFT_JDBC_DRIVER; + import java.sql.Connection; import java.sql.DriverManager; import java.util.Collections; import java.util.HashMap; import java.util.Map; - -import com.google.common.base.Optional; - -import static com.rapleaf.jack.DatabaseConnectionConstants.MYSQL_JDBC_DRIVER; -import static com.rapleaf.jack.DatabaseConnectionConstants.POSTGRESQL_JDBC_DRIVER; -import static com.rapleaf.jack.DatabaseConnectionConstants.REDSHIFT_JDBC_DRIVER; +import java.util.Optional; +import java.util.stream.Collectors; /** - * The DatabaseConnection class manages connections to your databases. The - * database to be used is specified in config/environment.yml. This file - * in turn points to a set of login credentials in config/database.yml. + * The DatabaseConnection class manages connections to your databases. The database to be used is specified in + * config/environment.yml. This file in turn points to a set of login credentials in config/database.yml. *

- * All public methods methods of DatabaseConnection throw RuntimeExceptions - * (rather than IO or SQL exceptions). + * All public methods methods of DatabaseConnection throw RuntimeExceptions (rather than IO or SQL exceptions). */ public class DatabaseConnection extends BaseDatabaseConnection { + private static final String PARTITION_NUM_ENV_VARIABLE_NAME = "TLB_PARTITION_NUMBER"; private static final int DEFAULT_CONNECTION_MAX_RETRIES = 7; private static final int DEFAULT_VALIDATION_TIMEOUT_SECONDS = 3; @@ -72,6 +71,15 @@ public class DatabaseConnection extends BaseDatabaseConnection { private long expiration; public DatabaseConnection(String dbname_key, long expiration, String driverClass) { + this(dbname_key, expiration, driverClass, Collections.emptyMap()); + } + + public DatabaseConnection( + String dbname_key, + long expiration, + String driverClass, + Map additionalConnectionStringOptions + ) { DatabaseConnectionConfiguration config = DatabaseConnectionConfiguration.loadFromEnvironment(dbname_key); // get server credentials from database info String adapter = config.getAdapter(); @@ -80,7 +88,6 @@ public DatabaseConnection(String dbname_key, long expiration, String driverClass } this.driverClass = driverClass; - String driver = adapter; if (ADAPTER_TO_DRIVER.containsKey(adapter)) { driver = ADAPTER_TO_DRIVER.get(adapter); @@ -92,26 +99,34 @@ public DatabaseConnection(String dbname_key, long expiration, String driverClass connectionStringBuilder.append(":").append(config.getPort().get()); } connectionStringBuilder.append("/").append(getDbName(config.getDatabaseName(), config.enableParallelTests())); + if (!additionalConnectionStringOptions.isEmpty()) { + connectionStringBuilder.append("?"); + connectionStringBuilder.append( + additionalConnectionStringOptions.entrySet() + .stream() + .map(e -> String.join("=", e.getKey(), e.getValue())) + .collect(Collectors.joining("&")) + ); + } connectionString = connectionStringBuilder.toString(); username = config.getUsername(); password = config.getPassword(); connectionMaxRetries = config.getConnectionMaxRetries() - .or(DEFAULT_CONNECTION_MAX_RETRIES); + .orElse(DEFAULT_CONNECTION_MAX_RETRIES); validationTimeoutSeconds = config.getConnectionValidationTimeout() - .or(DEFAULT_VALIDATION_TIMEOUT_SECONDS); + .orElse(DEFAULT_VALIDATION_TIMEOUT_SECONDS); this.expiration = expiration; updateExpiration(); } /** - * Get a Connection to a database. If there is no connection, create a new one. - * If the connection hasn't been used in a long time, close it and create a new one. - * We do this because MySQL has an 8 hour idle connection timeout. + * Get a Connection to a database. If there is no connection, create a new one. If the connection hasn't been used in + * a long time, close it and create a new one. We do this because MySQL has an 8 hour idle connection timeout. *

- * Because of the intermittent downtime of SQL service, it implements exponential retry policy. - * The default retry policy retries seven times, which handles SQL downtime less than approx. two minutes. + * Because of the intermittent downtime of SQL service, it implements exponential retry policy. The default retry + * policy retries seven times, which handles SQL downtime less than approx. two minutes. */ public Connection getConnectionInternal() { @@ -119,7 +134,7 @@ public Connection getConnectionInternal() { try { if (conn == null) { Class.forName(driverClass); - conn = DriverManager.getConnection(connectionString, username.orNull(), password.orNull()); + conn = DriverManager.getConnection(connectionString, username.orElse(null), password.orElse(null)); } else if (isExpired() || !conn.isValid(validationTimeoutSeconds)) { resetConnection(); } @@ -128,7 +143,8 @@ public Connection getConnectionInternal() { } catch (Exception e) { //IOEx., SQLEx. // if it is the last retry, throw exception if (retryCount == connectionMaxRetries - 1) { - throw new RuntimeException(String.format("Could not establish connection after %d attempts", + throw new RuntimeException(String.format( + "Could not establish connection after %d attempts", connectionMaxRetries ), e); } @@ -140,19 +156,19 @@ public Connection getConnectionInternal() { } } } - throw new RuntimeException(String.format("Could not establish connection after %d attempts", + throw new RuntimeException(String.format( + "Could not establish connection after %d attempts", connectionMaxRetries )); } /** - * When using a parallel test environment, we append an integer that lives in - * an environment variable to the database name. + * When using a parallel test environment, we append an integer that lives in an environment variable to the database + * name. * * @param base_name the name of the database - * @param use_parallel if true, append an integer specified in an environment - * variable to the end of base_name + * @param use_parallel if true, append an integer specified in an environment variable to the end of base_name * @return the name of the database that we should connect to */ protected String getDbName(String base_name, Boolean use_parallel) { diff --git a/jack-core/src/com/rapleaf/jack/DatabaseConnectionConfiguration.java b/jack-core/src/com/rapleaf/jack/DatabaseConnectionConfiguration.java index f2c6fc0e..71984385 100644 --- a/jack-core/src/com/rapleaf/jack/DatabaseConnectionConfiguration.java +++ b/jack-core/src/com/rapleaf/jack/DatabaseConnectionConfiguration.java @@ -3,18 +3,17 @@ import java.io.FileReader; import java.io.Reader; import java.io.StringReader; +import java.util.HashMap; import java.util.Map; - -import com.google.common.base.Function; -import com.google.common.base.Optional; -import com.google.common.collect.Maps; +import java.util.Optional; +import java.util.function.Function; import org.jvyaml.YAML; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class DatabaseConnectionConfiguration { - private final static Logger LOG = LoggerFactory.getLogger(DatabaseConnectionConfiguration.class); + private static final Logger LOG = LoggerFactory.getLogger(DatabaseConnectionConfiguration.class); public static final String ADAPTER_PROP_PREFIX = "jack.db.adapter"; public static final String HOST_PROP_PREFIX = "jack.db.host"; @@ -41,7 +40,15 @@ public class DatabaseConnectionConfiguration { private Optional connectionMaxRetries; private Optional connectionValidationTimeout; - public DatabaseConnectionConfiguration(String adapter, String host, String dbName, Optional port, Optional parallelTest, Optional username, Optional password) { + public DatabaseConnectionConfiguration( + String adapter, + String host, + String dbName, + Optional port, + Optional parallelTest, + Optional username, + Optional password + ) { this( adapter, host, @@ -50,12 +57,22 @@ public DatabaseConnectionConfiguration(String adapter, String host, String dbNam parallelTest, username, password, - Optional.absent(), - Optional.absent() + Optional.empty(), + Optional.empty() ); } - public DatabaseConnectionConfiguration(String adapter, String host, String dbName, Optional port, Optional parallelTest, Optional username, Optional password, Optional connectionMaxRetries, Optional connectionValidationTimeout) { + public DatabaseConnectionConfiguration( + String adapter, + String host, + String dbName, + Optional port, + Optional parallelTest, + Optional username, + Optional password, + Optional connectionMaxRetries, + Optional connectionValidationTimeout + ) { this.adapter = adapter; this.host = host; this.dbName = dbName; @@ -73,49 +90,53 @@ public static DatabaseConnectionConfiguration loadFromEnvironment(String dbNameK // load database info from some file - first check env, then props, then default location envInfo = fetchInfoMap( "environment", - new EnvVarProvider(envVar(ENVIRONMENT_YML_PROP)), - new PropertyProvider(ENVIRONMENT_YML_PROP), - new FileReaderProvider(System.getenv(envVar(ENVIRONMENT_PATH_PROP))), - new FileReaderProvider(System.getProperty(ENVIRONMENT_PATH_PROP)), - new FileReaderProvider("config/environment.yml"), - new FileReaderProvider("environment.yml")); + envVarProvider(envVar(ENVIRONMENT_YML_PROP)), + propertyProvider(ENVIRONMENT_YML_PROP), + fileReaderProvider(System.getenv(envVar(ENVIRONMENT_PATH_PROP))), + fileReaderProvider(System.getProperty(ENVIRONMENT_PATH_PROP)), + fileReaderProvider("config/environment.yml"), + fileReaderProvider("environment.yml")); - String db_info_name = (String)envInfo.get(dbNameKey); + String db_info_name = (String) envInfo.get(dbNameKey); - dbInfo = (Map)fetchInfoMap( + dbInfo = (Map) fetchInfoMap( "database", - new EnvVarProvider(envVar(DATABASE_YML_PROP)), - new PropertyProvider(DATABASE_YML_PROP), - new FileReaderProvider(System.getenv(envVar(DATABASE_PATH_PROP))), - new FileReaderProvider(System.getProperty(DATABASE_PATH_PROP)), - new FileReaderProvider("config/database.yml"), - new FileReaderProvider("database.yml")).get(db_info_name); + envVarProvider(envVar(DATABASE_YML_PROP)), + propertyProvider(DATABASE_YML_PROP), + fileReaderProvider(System.getenv(envVar(DATABASE_PATH_PROP))), + fileReaderProvider(System.getProperty(DATABASE_PATH_PROP)), + fileReaderProvider("config/database.yml"), + fileReaderProvider("database.yml")).get(db_info_name); String adapter = load("adapter", dbInfo, "adapter", "database", - envVar(ADAPTER_PROP_PREFIX, dbNameKey), prop(ADAPTER_PROP_PREFIX, dbNameKey), new StringIdentity()); + envVar(ADAPTER_PROP_PREFIX, dbNameKey), prop(ADAPTER_PROP_PREFIX, dbNameKey), Function.identity()); String host = load("host", dbInfo, "host", "database", - envVar(HOST_PROP_PREFIX, dbNameKey), prop(HOST_PROP_PREFIX, dbNameKey), new StringIdentity()); + envVar(HOST_PROP_PREFIX, dbNameKey), prop(HOST_PROP_PREFIX, dbNameKey), Function.identity()); String dbName = load("database name", dbInfo, "database", "database", - envVar(NAME_PROP_PREFIX, dbNameKey), prop(NAME_PROP_PREFIX, dbNameKey), new StringIdentity()); + envVar(NAME_PROP_PREFIX, dbNameKey), prop(NAME_PROP_PREFIX, dbNameKey), Function.identity()); Optional port = loadOpt(dbInfo, "port", - envVar(PORT_PROP_PREFIX, dbNameKey), prop(PORT_PROP_PREFIX, dbNameKey), new ToInteger()); + envVar(PORT_PROP_PREFIX, dbNameKey), prop(PORT_PROP_PREFIX, dbNameKey), Integer::parseInt); - Optional parallelTesting = loadOpt(envInfo, "enable_parallel_tests", - envVar(PARALLEL_TEST_PROP_PREFIX, dbNameKey), prop(PARALLEL_TEST_PROP_PREFIX, dbNameKey), new ToBoolean()); + Optional parallelTesting = loadOpt( + envInfo, + "enable_parallel_tests", + envVar(PARALLEL_TEST_PROP_PREFIX, dbNameKey), + prop(PARALLEL_TEST_PROP_PREFIX, dbNameKey), + Boolean::parseBoolean); Optional username = loadOpt(dbInfo, "username", - envVar(USERNAME_PROP_PREFIX, dbNameKey), prop(USERNAME_PROP_PREFIX, dbNameKey), new StringIdentity()); + envVar(USERNAME_PROP_PREFIX, dbNameKey), prop(USERNAME_PROP_PREFIX, dbNameKey), Function.identity()); Optional password = loadOpt(dbInfo, "password", - envVar(PASSWORD_PROP_PREFIX, dbNameKey), prop(PASSWORD_PROP_PREFIX, dbNameKey), new StringIdentity()); + envVar(PASSWORD_PROP_PREFIX, dbNameKey), prop(PASSWORD_PROP_PREFIX, dbNameKey), Function.identity()); Optional connectionMaxRetriesLong = loadOpt(dbInfo, "connection_max_retries", - envVar(CONNECTION_MAX_RETRIES, dbNameKey), prop(CONNECTION_MAX_RETRIES, dbNameKey), new ToLong()); + envVar(CONNECTION_MAX_RETRIES, dbNameKey), prop(CONNECTION_MAX_RETRIES, dbNameKey), Long::parseLong); - Optional connectionMaxRetries = Optional.absent(); + Optional connectionMaxRetries = Optional.empty(); if (connectionMaxRetriesLong.isPresent()) { /** * This manual transformation is necessary because the underlying type parsed by @@ -128,19 +149,29 @@ public static DatabaseConnectionConfiguration loadFromEnvironment(String dbNameK connectionMaxRetries = Optional.of(connectionMaxRetriesLong.get().intValue()); } - Optional connectionValidationTimeoutLong = loadOpt(dbInfo, + Optional connectionValidationTimeoutLong = loadOpt( + dbInfo, "connection_validation_timeout", envVar(CONNECTION_VALIDATION_TIMEOUT, dbNameKey), prop(CONNECTION_VALIDATION_TIMEOUT, dbNameKey), - new ToLong() + Long::parseLong ); - Optional connectionValidationTimeout = Optional.absent(); + Optional connectionValidationTimeout = Optional.empty(); if (connectionValidationTimeoutLong.isPresent()) { connectionValidationTimeout = Optional.of(connectionValidationTimeoutLong.get().intValue()); } - return new DatabaseConnectionConfiguration(adapter, host, dbName, port, parallelTesting, username, password, connectionMaxRetries, connectionValidationTimeout); + return new DatabaseConnectionConfiguration( + adapter, + host, + dbName, + port, + parallelTesting, + username, + password, + connectionMaxRetries, + connectionValidationTimeout); } private static Map fetchInfoMap(String configName, ReaderProvider... readers) { @@ -148,17 +179,18 @@ private static Map fetchInfoMap(String configName, ReaderProvide try { Optional readerOptional = reader.get(); if (readerOptional.isPresent()) { - return (Map)YAML.load(readerOptional.get()); + return (Map) YAML.load(readerOptional.get()); } } catch (Exception e) { //move to next reader } } LOG.error("no yaml found for config: " + configName); - return Maps.newHashMap(); + return new HashMap<>(); } private interface ReaderProvider { + Optional get() throws Exception; } @@ -181,7 +213,8 @@ private static T load( String mapYmlFile, String envVar, String javaProp, - Function fromString) { + Function fromString + ) { Optional result = loadOpt(map, mapKey, envVar, javaProp, fromString); if (result.isPresent()) { @@ -201,17 +234,18 @@ private static Optional loadOpt( String mapKey, String envVar, String javaProp, - Function fromString) { + Function fromString + ) { if (System.getenv(envVar) != null) { - return Optional.fromNullable(fromString.apply(System.getenv(envVar))); + return Optional.ofNullable(fromString.apply(System.getenv(envVar))); } if (System.getProperty(javaProp) != null) { - return Optional.fromNullable(fromString.apply(System.getProperty(javaProp))); + return Optional.ofNullable(fromString.apply(System.getProperty(javaProp))); } if (map != null && map.containsKey(mapKey)) { - return Optional.fromNullable((T)map.get(mapKey)); + return Optional.ofNullable((T) map.get(mapKey)); } - return Optional.absent(); + return Optional.empty(); } @@ -232,7 +266,7 @@ public String getDatabaseName() { } public Boolean enableParallelTests() { - return parallelTest.isPresent() ? parallelTest.get() : false; + return parallelTest.orElse(false); } public Optional getUsername() { @@ -251,81 +285,24 @@ public Optional getConnectionValidationTimeout() { return connectionValidationTimeout; } - private static class StringIdentity implements Function { - public String apply(String s) { - return s; - } - } - - private static class ToInteger implements Function { - public Integer apply(String s) { - return Integer.parseInt(s); - } - } - - private static class ToLong implements Function { - public Long apply(String s) { - return Long.parseLong(s); - } - } - - private static class ToBoolean implements Function { - public Boolean apply(String s) { - return Boolean.parseBoolean(s); - } - } - - private static class FileReaderProvider implements ReaderProvider { - - private String file; - - public FileReaderProvider(String file) { - this.file = file; - } - - @Override - public Optional get() throws Exception { + // FileReader::new can throw an exception, so we can't use the Optional::map shorthand + private static ReaderProvider fileReaderProvider(String file) { + return () -> { if (file != null) { return Optional.of(new FileReader(file)); } else { - return Optional.absent(); + return Optional.empty(); } - } + }; } - private static class EnvVarProvider implements ReaderProvider { - - private String envVar; - - public EnvVarProvider(String envVar) { - this.envVar = envVar; - } - - @Override - public Optional get() throws Exception { - if (System.getenv(envVar) != null) { - return Optional.of(new StringReader(System.getenv(envVar))); - } else { - return Optional.absent(); - } - } + private static ReaderProvider envVarProvider(String envVar) { + return () -> Optional.ofNullable(System.getenv(envVar)) + .map(StringReader::new); } - private static class PropertyProvider implements ReaderProvider { - - private String property; - - public PropertyProvider(String property) { - this.property = property; - } - - @Override - public Optional get() throws Exception { - if (System.getProperty(property) != null) { - return Optional.of(new StringReader(System.getProperty(property))); - } else { - return Optional.absent(); - } - } + private static ReaderProvider propertyProvider(String property) { + return () -> Optional.ofNullable(System.getProperty(property)) + .map(StringReader::new); } } diff --git a/jack-mysql/src/com/rapleaf/jack/MysqlDatabaseConnection.java b/jack-mysql/src/com/rapleaf/jack/MysqlDatabaseConnection.java index 8bd39cfd..f8d1f5f5 100644 --- a/jack-mysql/src/com/rapleaf/jack/MysqlDatabaseConnection.java +++ b/jack-mysql/src/com/rapleaf/jack/MysqlDatabaseConnection.java @@ -3,12 +3,27 @@ import static com.rapleaf.jack.DatabaseConnectionConstants.DEFAULT_EXPIRATION; import static com.rapleaf.jack.DatabaseConnectionConstants.MYSQL_JDBC_DRIVER; +import java.util.HashMap; +import java.util.Map; + public class MysqlDatabaseConnection extends DatabaseConnection { + + private static final Map sslOptions = new HashMap<>(); + + static { + sslOptions.put("verifyServerCertificate", "false"); + sslOptions.put("useSSL", "true"); + } + public MysqlDatabaseConnection(String dbname_key) { this(dbname_key, DEFAULT_EXPIRATION); } public MysqlDatabaseConnection(String dbname_key, long expiration) { - super(dbname_key, expiration, MYSQL_JDBC_DRIVER); + super(dbname_key, expiration, MYSQL_JDBC_DRIVER, sslOptions); + } + + public MysqlDatabaseConnection(String dbname_key, long expiration, Map additionalOptions) { + super(dbname_key, expiration, MYSQL_JDBC_DRIVER, additionalOptions); } }