diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index b6d3575210..60799f2fe8 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -247,6 +247,7 @@ def __init__( ssl_check_hostname: bool = True, ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, + ssl_password: Optional[str] = None, max_connections: Optional[int] = None, single_connection_client: bool = False, health_check_interval: int = 0, @@ -359,6 +360,7 @@ def __init__( "ssl_check_hostname": ssl_check_hostname, "ssl_min_version": ssl_min_version, "ssl_ciphers": ssl_ciphers, + "ssl_password": ssl_password, } ) # This arg only used if no pool is passed in diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d44c8da7e4..c13b921e2b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -817,6 +817,7 @@ def __init__( ssl_check_hostname: bool = True, ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, + ssl_password: Optional[str] = None, **kwargs, ): if not SSL_AVAILABLE: @@ -834,6 +835,7 @@ def __init__( check_hostname=ssl_check_hostname, min_version=ssl_min_version, ciphers=ssl_ciphers, + password=ssl_password, ) super().__init__(**kwargs) @@ -893,6 +895,7 @@ class RedisSSLContext: "check_hostname", "min_version", "ciphers", + "password", ) def __init__( @@ -908,6 +911,7 @@ def __init__( check_hostname: bool = False, min_version: Optional[TLSVersion] = None, ciphers: Optional[str] = None, + password: Optional[str] = None, ): if not SSL_AVAILABLE: raise RedisError("Python wasn't built with SSL support") @@ -938,6 +942,7 @@ def __init__( ) self.min_version = min_version self.ciphers = ciphers + self.password = password self.context: Optional[SSLContext] = None def get(self) -> SSLContext: @@ -951,8 +956,12 @@ def get(self) -> SSLContext: if self.exclude_verify_flags: for flag in self.exclude_verify_flags: context.verify_flags &= ~flag - if self.certfile and self.keyfile: - context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) + if self.certfile or self.keyfile: + context.load_cert_chain( + certfile=self.certfile, + keyfile=self.keyfile, + password=self.password, + ) if self.ca_certs or self.ca_data or self.ca_path: context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data diff --git a/tests/test_asyncio/test_ssl.py b/tests/test_asyncio/test_ssl.py index 54a8e2f28c..c5cbb90cfe 100644 --- a/tests/test_asyncio/test_ssl.py +++ b/tests/test_asyncio/test_ssl.py @@ -167,3 +167,29 @@ async def test_ssl_ca_path_parameter(self, request): assert conn.ssl_context.ca_path == test_ca_path finally: await r.aclose() + + async def test_ssl_password_parameter(self, request): + """Test that ssl_password parameter is properly passed to SSLConnection""" + ssl_url = request.config.option.redis_ssl_url + parsed_url = urlparse(ssl_url) + + # Test with a mock password for encrypted private key + test_password = "test_key_password" + + r = redis.Redis( + host=parsed_url.hostname, + port=parsed_url.port, + ssl=True, + ssl_cert_reqs="none", + ssl_password=test_password, + ) + + try: + # Get the connection to verify ssl_password is passed through + conn = r.connection_pool.make_connection() + assert isinstance(conn, redis.SSLConnection) + + # Verify the password is stored in the SSL context + assert conn.ssl_context.password == test_password + finally: + await r.aclose()