Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions ossdbtoolsservice/object_explorer/object_explorer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def register(self, service_provider: ServiceProvider):
def _handle_create_session_request(self, request_context: RequestContext, params: ConnectionDetails) -> None:
"""Handle a create object explorer session request"""
# Step 1: Create the session
session_exist_check = False
try:
# Make sure we have the appropriate session params
utils.validate.is_not_none('params', params)
Expand All @@ -94,7 +93,6 @@ def _handle_create_session_request(self, request_context: RequestContext, params
with self._session_lock:
if session_id in self._session_map:
# If session already exists, get it and respond with it
session_exist_check = True
if self._service_provider.logger is not None:
self._service_provider.logger.info(f'Object explorer session for {session_id} already exists. Returning existing session.')
session = self._session_map[session_id]
Expand All @@ -116,10 +114,9 @@ def _handle_create_session_request(self, request_context: RequestContext, params

# Step 2: Connect the session and lookup the root node asynchronously
try:
if not session_exist_check:
session.init_task = threading.Thread(target=self._initialize_session, args=(request_context, session))
session.init_task.daemon = True
session.init_task.start()
session.init_task = threading.Thread(target=self._initialize_session, args=(request_context, session))
session.init_task.daemon = True
session.init_task.start()
except Exception as e:
# TODO: Localize
self._session_created_error(request_context, session, f'Failed to start OE init task: {str(e)}')
Expand Down Expand Up @@ -354,16 +351,18 @@ def _generate_session_uri(params: ConnectionDetails, provider_name: str) -> str:
if provider_name == utils.constants.PG_PROVIDER_NAME:
utils.validate.is_not_none_or_whitespace('params.database_name', params.options.get('dbname'))
utils.validate.is_not_none('params.port', params.options.get('port'))
utils.validate.is_not_none_or_whitespace('params.groupId', params.options.get('groupId'))

# Generates a session ID that will function as the base URI for the session
host = quote(params.options['host'])
user = quote(params.options['user'])
db = quote(params.options['dbname'])
group_id = quote(params.options['groupId'].lower())
# Port number distinguishes between connections to different server
# instances with the same username, dbname running on same host
port = quote(str(params.options['port']))

return f'objectexplorer://{user}@{host}:{port}:{db}/'
return f'objectexplorer://{group_id}.{user}@{host}:{port}:{db}/'

def _route_request(self, is_refresh: bool, session: ObjectExplorerSession, path: str) -> List[NodeInfo]:
"""
Expand Down
127 changes: 95 additions & 32 deletions tests/object_explorer/test_object_explorer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TEST_USER = 'testuser'
TEST_PASSWORD = 'testpassword'
TEST_PORT = 5432
TEST_GROUP_ID = '8b66e74f-50f0-4f47-b525-880274ce3f98'


def _connection_details() -> Tuple[ConnectionDetails, str]:
Expand All @@ -44,7 +45,8 @@ def _connection_details() -> Tuple[ConnectionDetails, str]:
'host': TEST_HOST,
'dbname': TEST_DBNAME,
'user': TEST_USER,
'port': TEST_PORT
'port': TEST_PORT,
'groupId': TEST_GROUP_ID
}
session_uri = ObjectExplorerService._generate_session_uri(param, constants.PG_PROVIDER_NAME)
return param, session_uri
Expand Down Expand Up @@ -93,10 +95,11 @@ def test_register(self):
def test_generate_uri_missing_params(self):
# Setup: Create the parameter sets that will be missing a param each
params = [
ConnectionDetails.from_data({'host': None, 'dbname': TEST_DBNAME, 'user': TEST_USER, 'port': TEST_PORT}),
ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': None, 'user': TEST_USER, 'port': TEST_PORT}),
ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': None, 'port': TEST_PORT}),
ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': TEST_USER, 'port': None})
ConnectionDetails.from_data({'host': None, 'dbname': TEST_DBNAME, 'user': TEST_USER, 'port': TEST_PORT, 'groupId': TEST_GROUP_ID}),
ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': None, 'user': TEST_USER, 'port': TEST_PORT, 'groupId': TEST_GROUP_ID}),
ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': None, 'port': TEST_PORT, 'groupId': TEST_GROUP_ID}),
ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': TEST_USER, 'port': None, 'groupId': TEST_GROUP_ID}),
ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': TEST_USER, 'port': TEST_PORT, 'groupId': None})
]

for param_set in params:
Expand All @@ -115,8 +118,11 @@ def test_generate_uri_valid_params(self):
self.assertEqual(parse_result.scheme, 'objectexplorer')
self.assertTrue(parse_result.netloc)

re_match = re.match(r'(?P<username>\w+)@(?P<host>\w+):(?P<port>\w+):(?P<db_name>\w+)', parse_result.netloc)
re_match = re.match(r'(?P<group_id>[0-9a-fA-F]{8}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{12})\.' +
r'(?P<username>\w+)@(?P<host>\w+):(?P<port>\w+):(?P<db_name>\w+)',
parse_result.netloc)
self.assertIsNotNone(re_match)
self.assertEqual(re_match.group('group_id'), TEST_GROUP_ID)
self.assertEqual(re_match.group('username'), TEST_USER)
self.assertEqual(re_match.group('host'), TEST_HOST)
self.assertEqual(int(re_match.group('port')), TEST_PORT)
Expand Down Expand Up @@ -171,20 +177,18 @@ def test_handle_create_session_session_exists(self):
session = ObjectExplorerSession(session_uri, params)
oe._session_map[session_uri] = session
oe._provider = constants.PG_PROVIDER_NAME
oe._server = Server

# If: I attempt to create an OE session that already exists
rc = RequestFlowValidator()
rc.add_expected_response(
CreateSessionResponse,
lambda param: self.assertEqual(param.session_id, session_uri)
)
rc = self._create_request_flow_validator(session_uri, oe)
oe._handle_create_session_request(rc.request_context, params)
oe._session_map[session_uri].init_task.join()

# Then:
# ... I should get a response as False
# ... I should get re-initialize the session and retrieve it
rc.validate()

# ... The old session should remain
# ... The old session id should still be there in the session map
self.assertIs(oe._session_map[session_uri], session)

def test_handle_create_session_threading_fail(self):
Expand Down Expand Up @@ -234,7 +238,68 @@ def test_handle_create_session_successful(self):
# ... Create parameters, session, request context validator
params, session_uri = _connection_details()

# ... Create validation of success notification
rc = self._create_request_flow_validator(session_uri, oe)

# If: I create a session
oe._handle_create_session_request(rc.request_context, params)
oe._session_map[session_uri].init_task.join()

# Then:
# ... Error notification should have been returned, session should be cleaned up from OE service
rc.validate()

# ... The session should still exist and should have connection and server setup
self.assertIn(session_uri, oe._session_map)
self.assertIsInstance(oe._session_map[session_uri].server, Server)
self.assertTrue(oe._session_map[session_uri].is_ready)

def test_handle_create_session_successful_multiple_sessions_different_groups(self):
# Setup:
# ... Create OE service with mock connection service that returns a successful connection response
mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='postgres', port=123)
cs = ConnectionService()
cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
cs.get_connection = mock.MagicMock(return_value=mock_connection)
oe = ObjectExplorerService()
oe._service_provider = utils.get_mock_service_provider({constants.CONNECTION_SERVICE_NAME: cs})
oe._provider = constants.PG_PROVIDER_NAME
oe._server = Server

# ... Create parameters, session, request context validator
params_1 = ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': TEST_USER, 'port': TEST_PORT, 'groupId': TEST_GROUP_ID})
session_uri_1 = ObjectExplorerService._generate_session_uri(params_1, constants.PG_PROVIDER_NAME)
TEST_GROUP_ID_2 = 'be804233-bfa3-496a-8ba3-bf6f6380280c'
params_2 = ConnectionDetails.from_data({'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': TEST_USER, 'port': TEST_PORT, 'groupId': TEST_GROUP_ID_2})
session_uri_2 = ObjectExplorerService._generate_session_uri(params_2, constants.PG_PROVIDER_NAME)

# Create the request flow validators
rc_1 = self._create_request_flow_validator(session_uri_1, oe)
rc_2 = self._create_request_flow_validator(session_uri_2, oe)

# If: I create a session
oe._handle_create_session_request(rc_1.request_context, params_1)
oe._session_map[session_uri_1].init_task.join()

# Then:
# ... Error notifications should have been returned, session should be cleaned up from OE service
rc_1.validate()

# Create the duplicate session and validate it as well
oe._handle_create_session_request(rc_2.request_context, params_2)
oe._session_map[session_uri_2].init_task.join()
rc_2.validate()

# ... The sessions should still exist and should have connection and server setup
self.assertIn(session_uri_1, oe._session_map)
self.assertIsInstance(oe._session_map[session_uri_1].server, Server)
self.assertTrue(oe._session_map[session_uri_1].is_ready)

self.assertIn(session_uri_2, oe._session_map)
self.assertIsInstance(oe._session_map[session_uri_2].server, Server)
self.assertTrue(oe._session_map[session_uri_2].is_ready)

# ... Create validation of success notification
def _generate_validate_success_notifcation_function(self, session_uri, oe: ObjectExplorerService):
def validate_success_notification(response: SessionCreatedParameters):
self.assertTrue(response.success)
self.assertEqual(response.session_id, session_uri)
Expand All @@ -249,7 +314,9 @@ def validate_success_notification(response: SessionCreatedParameters):
self.assertEqual(response.root_node.metadata.name, oe._session_map[session_uri].server.maintenance_db_name)
self.assertEqual(response.root_node.metadata.metadata_type_name, 'Database')
self.assertFalse(response.root_node.is_leaf)
return validate_success_notification

def _create_request_flow_validator(self, session_uri, oe: ObjectExplorerService) -> RequestFlowValidator:
rc = RequestFlowValidator()
rc.add_expected_response(
CreateSessionResponse,
Expand All @@ -258,21 +325,9 @@ def validate_success_notification(response: SessionCreatedParameters):
rc.add_expected_notification(
SessionCreatedParameters,
SESSION_CREATED_METHOD,
validate_success_notification
self._generate_validate_success_notifcation_function(session_uri, oe)
)

# If: I create a session
oe._handle_create_session_request(rc.request_context, params)
oe._session_map[session_uri].init_task.join()

# Then:
# ... Error notification should have been returned, session should be cleaned up from OE service
rc.validate()

# ... The session should still exist and should have connection and server setup
self.assertIn(session_uri, oe._session_map)
self.assertIsInstance(oe._session_map[session_uri].server, Server)
self.assertTrue(oe._session_map[session_uri].is_ready)
return rc

def test_init_session_cancelled_connection(self):
# Setup:
Expand Down Expand Up @@ -700,7 +755,9 @@ def test_handle_close_session_unsuccessful(self):

# Then: I should get a successful response
rc.validate()
self.oe._service_provider.logger.info.assert_called_with('Could not close the OE session with Id objectexplorer://testuser@testhost:5432:testdb/')
self.oe._service_provider.logger.info.assert_called_with(
'Could not close the OE session with Id objectexplorer://8b66e74f-50f0-4f47-b525-880274ce3f98.testuser@testhost:5432:testdb/'
)

def test_handle_close_session_throwsException(self):
# setup to throw exception on disconnect
Expand Down Expand Up @@ -738,23 +795,29 @@ def test_handle_close_session_successful(self):
def test_handle_shutdown_successfulWithSessions(self):
# shutdown the session
self.oe._handle_shutdown()
self.oe._service_provider.logger.info.assert_called_with('Closed the OE session with Id: objectexplorer://testuser@testhost:5432:testdb/')
self.oe._service_provider.logger.info.assert_called_with(
'Closed the OE session with Id: objectexplorer://8b66e74f-50f0-4f47-b525-880274ce3f98.testuser@testhost:5432:testdb/'
)

def test_handle_shutdown_successfulNoDatabase(self):
# Setup: Create an OE service and add a session to it
self.session.server._child_objects[Database.__name__] = []

# shutdown the session
self.oe._handle_shutdown()
self.oe._service_provider.logger.info.assert_called_with('Closed the OE session with Id: objectexplorer://testuser@testhost:5432:testdb/')
self.oe._service_provider.logger.info.assert_called_with(
'Closed the OE session with Id: objectexplorer://8b66e74f-50f0-4f47-b525-880274ce3f98.testuser@testhost:5432:testdb/'
)

def test_handle_shutdown_UnsuccessfulWithSessions(self):
# Setup: Create an OE service and add a session to it
self.cs.disconnect = mock.MagicMock(return_value=False)

# shutdown the session
self.oe._handle_shutdown()
self.oe._service_provider.logger.info.assert_called_with('Could not close the OE session with Id: objectexplorer://testuser@testhost:5432:testdb/')
self.oe._service_provider.logger.info.assert_called_with(
'Could not close the OE session with Id: objectexplorer://8b66e74f-50f0-4f47-b525-880274ce3f98.testuser@testhost:5432:testdb/'
)

def test_handle_shutdown_successfulNoSessions(self):
# Setup: Create an empty session dictionary
Expand Down