diff --git a/config.go b/config.go index 472a362..a05afad 100644 --- a/config.go +++ b/config.go @@ -4,8 +4,11 @@ import ( "log" "os" "strconv" + "strings" ) +const defaultSSLMode = "disable" + // PoolMaxConnLifetime and PoolMaxConnIdle are string time duration representations // as defined in ParseDuration in the stdlib time package // the format consists of decimal numbers, each with optional fraction and a unit suffix, @@ -21,6 +24,7 @@ type RdbmsConfig struct { OnInit string DbDriver string DbStore string + DbSSLMode string PoolMaxConns int PoolMinConns int @@ -30,6 +34,15 @@ type RdbmsConfig struct { DbDriverSettings string } +var sslModeMap = map[string]string{ + "disable": "disable", + "allow": "allow", + "prefer": "prefer", + "require": "require", + "verify-ca": "verify-ca", + "verify-full": "verify-full", +} + func RdbmsConfigFromEnv() *RdbmsConfig { dbConfig := new(RdbmsConfig) dbConfig.Dbuser = os.Getenv("DBUSER") @@ -40,11 +53,23 @@ func RdbmsConfigFromEnv() *RdbmsConfig { dbConfig.DbDriver = os.Getenv("DBDRIVER") dbConfig.DbStore = os.Getenv("DBSTORE") dbConfig.ExternalLib = os.Getenv("EXTERNAL_LIB") + dbConfig.DbSSLMode = os.Getenv("DBSSLMODE") if dbConfig.Dbport == "" { dbConfig.Dbport = "5432" } + if dbConfig.DbSSLMode == "" { + dbConfig.DbSSLMode = defaultSSLMode + } else { + if sslMode, ok := sslModeMap[strings.ToLower(dbConfig.DbSSLMode)]; ok { + dbConfig.DbSSLMode = sslMode + } else { + log.Printf("Error parsing DBSSLMODE value of \"%s\": Will fall back to default DBSSLMODE value.\n", dbConfig.DbSSLMode) + dbConfig.DbSSLMode = defaultSSLMode + } + } + maxConns := os.Getenv("POOLMAXCONNS") mc, err := strconv.Atoi(maxConns) if err != nil { diff --git a/pg_dialect.go b/pg_dialect.go index 4befa8e..6ed6714 100644 --- a/pg_dialect.go +++ b/pg_dialect.go @@ -1,6 +1,9 @@ package goquery -import "fmt" +import ( + "fmt" + "log" +) var pgDialect = DbDialect{ TableExistsStmt: `SELECT count(*) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2`, @@ -11,7 +14,11 @@ var pgDialect = DbDialect{ return fmt.Sprintf("nextval('%s')", sequence) }, Url: func(config *RdbmsConfig) string { - return fmt.Sprintf("user=%s password=%s host=%s port=%s database=%s sslmode=disable", - config.Dbuser, config.Dbpass, config.Dbhost, config.Dbport, config.Dbname) + if config.DbSSLMode == "" { + config.DbSSLMode = defaultSSLMode + log.Printf("No sslmode set, will fall back to default DBSSLMODE value %s. Set value in the dbconfig using DbSSLMode \n", defaultSSLMode) + } + return fmt.Sprintf("user=%s password=%s host=%s port=%s database=%s sslmode=%s", + config.Dbuser, config.Dbpass, config.Dbhost, config.Dbport, config.Dbname, config.DbSSLMode) }, } diff --git a/pgx_db.go b/pgx_db.go index 32038ab..21e16da 100644 --- a/pgx_db.go +++ b/pgx_db.go @@ -3,6 +3,7 @@ package goquery import ( "context" "fmt" + "log" "reflect" "time" @@ -107,8 +108,12 @@ type PgxDb struct { } func NewPgxConnection(config *RdbmsConfig) (PgxDb, error) { - dburl := fmt.Sprintf("user=%s password=%s host=%s port=%s database=%s sslmode=disable", - config.Dbuser, config.Dbpass, config.Dbhost, config.Dbport, config.Dbname) + if config.DbSSLMode == "" { + config.DbSSLMode = defaultSSLMode + log.Printf("No sslmode set, will fall back to default DBSSLMODE value %s. Set value in the dbconfig using DbSSLMode \n", defaultSSLMode) + } + dburl := fmt.Sprintf("user=%s password=%s host=%s port=%s database=%s sslmode=%s", + config.Dbuser, config.Dbpass, config.Dbhost, config.Dbport, config.Dbname, config.DbSSLMode) if config.PoolMaxConns > 0 { dburl = fmt.Sprintf("%s %s=%d", dburl, "pool_max_conns", config.PoolMaxConns)