diff --git a/ossdbtoolsservice/driver/types/driver.py b/ossdbtoolsservice/driver/types/driver.py index 5bd2f680d..2f73c8470 100644 --- a/ossdbtoolsservice/driver/types/driver.py +++ b/ossdbtoolsservice/driver/types/driver.py @@ -61,6 +61,11 @@ def default_database(self) -> str: def database_error(self) -> Exception: """ Returns the type of database error this connection throws""" + @property + @abstractmethod + def transaction_is_active(self) -> bool: + """Returns bool indicating if transaction is active""" + @property @abstractmethod def transaction_in_error(self) -> bool: @@ -71,6 +76,11 @@ def transaction_in_error(self) -> bool: def transaction_is_idle(self) -> bool: """Returns bool indicating if transaction is currently idle""" + @property + @abstractmethod + def transaction_in_unknown(self) -> bool: + """Returns bool indicating if transaction is in unknown state""" + @property @abstractmethod def transaction_in_trans(self) -> bool: @@ -179,6 +189,12 @@ def close(self): Closes this current connection. 
""" + + def transaction_status(self): + """ + Gets the current transaction status if it exists + """ + @abstractmethod def set_transaction_in_error(self): """ diff --git a/ossdbtoolsservice/driver/types/psycopg_driver.py b/ossdbtoolsservice/driver/types/psycopg_driver.py index 8cd2711a9..13c72c70b 100644 --- a/ossdbtoolsservice/driver/types/psycopg_driver.py +++ b/ossdbtoolsservice/driver/types/psycopg_driver.py @@ -150,6 +150,11 @@ def database_error(self): """Returns the type of database error this connection throws""" return self._database_error + @property + def transaction_is_active(self) -> bool: + """Returns bool indicating if transaction is active""" + return self._conn.info.transaction_status is TransactionStatus.ACTIVE + @property def transaction_in_error(self) -> bool: """Returns bool indicating if transaction is in error""" @@ -159,6 +164,11 @@ def transaction_in_error(self) -> bool: def transaction_is_idle(self) -> bool: """Returns bool indicating if transaction is currently idle""" return self._conn.info.transaction_status is TransactionStatus.IDLE + + @property + def transaction_in_unknown(self) -> bool: + """Returns bool indicating if transaction is in unknown state""" + return self._conn.info.transaction_status is TransactionStatus.UNKNOWN @property def transaction_in_trans(self) -> bool: @@ -315,6 +325,14 @@ def close(self): """ self._conn.close() + def transaction_status(self): + """ + Gets the current transaction status if it exists + """ + if self._conn and self._conn.info: + return self._conn.info.transaction_status + return None + def set_transaction_in_error(self): """ Sets if current connection is in error diff --git a/ossdbtoolsservice/query/batch.py b/ossdbtoolsservice/query/batch.py index c30dd17db..96cfa8e58 100644 --- a/ossdbtoolsservice/query/batch.py +++ b/ossdbtoolsservice/query/batch.py @@ -107,50 +107,6 @@ def notices(self) -> List[str]: def get_cursor(self, connection: ServerConnection): return connection.cursor() - def 
execute(self, conn: ServerConnection) -> None: - """ - Execute the batch using a cursor retrieved from the given connection - - :raises DatabaseError: if an error is encountered while running the batch's query - """ - self._execution_start_time = datetime.now() - - if self._batch_events and self._batch_events._on_execution_started: - self._batch_events._on_execution_started(self) - - cursor = self.get_cursor(conn) - - conn.connection.add_notice_handler(lambda msg: self.notice_handler(msg, conn)) - - if self.batch_text.startswith('begin') and conn.transaction_in_trans: - self._notices.append('WARNING: there is already a transaction in progress') - - try: - cursor.execute(self.batch_text) - - # Commit the transaction if autocommit is True - if conn.autocommit: - conn.commit() - - self.after_execute(cursor) - except psycopg.DatabaseError as e: - self._has_error = True - conn.set_transaction_in_error() - raise e - finally: - if cursor and cursor.statusmessage is not None: - self.status_message = cursor.statusmessage - # We are doing this because when the execute fails for named cursors - # cursor is not activated on the server which results in failure on close - # Hence we are checking if the cursor was really executed for us to close it - if cursor and cursor.rowcount != -1 and cursor.rowcount is not None: - cursor.close() - self._has_executed = True - self._execution_end_time = datetime.now() - - if self._batch_events and self._batch_events._on_execution_completed: - self._batch_events._on_execution_completed(self) - def after_execute(self, cursor) -> None: if cursor.description is not None: self.create_result_set(cursor) @@ -201,14 +157,15 @@ def create_result_set(storage_type: ResultSetStorageType, result_set_id: int, ba return InMemoryResultSet(result_set_id, batch_id) -def create_batch(batch_text: str, ordinal: int, selection: SelectionData, batch_events: BatchEvents, storage_type: ResultSetStorageType) -> Batch: - sql = sqlparse.parse(batch_text) - statement = 
sql[0] - - if statement.get_type().lower() == 'select': - into_checker = [True for token in statement.tokens if token.normalized == 'INTO'] - cte_checker = [True for token in statement.tokens if token.ttype == sqlparse.tokens.Keyword.CTE] - if len(into_checker) == 0 and len(cte_checker) == 0: # SELECT INTO and CTE keywords can't be used in named cursor - return SelectBatch(batch_text, ordinal, selection, batch_events, storage_type) +def create_batch(batch_text: str, ordinal: int, selection: SelectionData, batch_events: BatchEvents, storage_type: ResultSetStorageType, select_batch: bool = False) -> Batch: + # sql = sqlparse.parse(batch_text) + # statement = sql[0] + # if statement.get_type().lower() == 'select': + # into_checker = [True for token in statement.tokens if token.normalized == 'INTO'] + # cte_checker = [True for token in statement.tokens if token.ttype == sqlparse.tokens.Keyword.CTE] + # if len(into_checker) == 0 and len(cte_checker) == 0: # SELECT INTO and CTE keywords can't be used in named cursor + # return SelectBatch(batch_text, ordinal, selection, batch_events, storage_type) + if select_batch: + return SelectBatch(batch_text, ordinal, selection, batch_events, storage_type) return Batch(batch_text, ordinal, selection, batch_events, storage_type) diff --git a/ossdbtoolsservice/query/query.py b/ossdbtoolsservice/query/query.py index 941ac0bae..ff39386af 100644 --- a/ossdbtoolsservice/query/query.py +++ b/ossdbtoolsservice/query/query.py @@ -3,14 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for license information. 
# -------------------------------------------------------------------------------------------- +from datetime import datetime from enum import Enum -from typing import Callable, Dict, List, Optional # noqa +from typing import Callable, Dict, List, Optional, Tuple # noqa import sqlparse from ossdbtoolsservice.driver import ServerConnection from ossdbtoolsservice.query import Batch, BatchEvents, create_batch, ResultSetStorageType from ossdbtoolsservice.query.contracts import SaveResultsRequestParams, SelectionData from ossdbtoolsservice.query.data_storage import FileStreamFactory +import psycopg +from utils import constants class QueryEvents: @@ -59,43 +62,27 @@ def __init__(self, owner_uri: str, query_text: str, query_execution_settings: Qu self._user_transaction = False self._current_batch_index = 0 self._batches: List[Batch] = [] + self._notices: List[str] = [] self._execution_plan_options = query_execution_settings.execution_plan_options + self._query_events = query_events + self._query_execution_settings = query_execution_settings self.is_canceled = False - - # Initialize the batches - statements = sqlparse.split(query_text) - selection_data = compute_selection_data_for_batches(statements, query_text) - - for index, batch_text in enumerate(statements): - # Skip any empty text - formatted_text = sqlparse.format(batch_text, strip_comments=True).strip() - if not formatted_text or formatted_text == ';': - continue - - sql_statement_text = batch_text - - # Create and save the batch - if bool(self._execution_plan_options): - if self._execution_plan_options.include_estimated_execution_plan_xml: - sql_statement_text = Query.EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text) - elif self._execution_plan_options.include_actual_execution_plan_xml: - self._disable_auto_commit = True - sql_statement_text = Query.ANALYZE_EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text) - - # Check if user defined transaction - if formatted_text.lower().startswith('begin'): + # Use the same 
selection data for all batches. We want to avoid parsing and splitting into separate SQL statements + self.selection_data = compute_selection_data_for_batches([self.query_text], self.query_text)[0] + + # # Create and save the batch + if bool(self._execution_plan_options): + if self._execution_plan_options.include_estimated_execution_plan_xml: + sql_statement_text = Query.EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text) + elif self._execution_plan_options.include_actual_execution_plan_xml: self._disable_auto_commit = True - self._user_transaction = True - - batch = create_batch( - sql_statement_text, - len(self.batches), - selection_data[index], - query_events.batch_events, - query_execution_settings.result_set_storage_type) + sql_statement_text = Query.ANALYZE_EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text) - self._batches.append(batch) + # Check if user defined transaction + if self.query_text.lower().startswith('begin'): + self._disable_auto_commit = True + self._user_transaction = True @property def owner_uri(self) -> str: @@ -116,6 +103,14 @@ def batches(self) -> List[Batch]: @property def current_batch_index(self) -> int: return self._current_batch_index + + @property + def query_events(self) -> QueryEvents: + return self._query_events + + @property + def query_execution_settings(self) -> QueryExecutionSettings: + return self._query_execution_settings def execute(self, connection: ServerConnection, retry_state=False): """ @@ -140,13 +135,56 @@ def execute(self, connection: ServerConnection, retry_state=False): if self._disable_auto_commit and connection.transaction_is_idle: connection.autocommit = False - for batch_index, batch in enumerate(self._batches): - self._current_batch_index = batch_index - - if self.is_canceled: - break + # Start a cursor block + batch_events: BatchEvents = None + if self.query_events is not None and self.query_events.batch_events is not None: + batch_events = self.query_events.batch_events + + 
connection.connection.add_notice_handler(lambda msg: self.notice_handler(msg, connection)) + with connection.cursor() as cur: + start_time = datetime.now() + + try: + if self.is_canceled: + return + cur.execute(self.query_text) + end_time = datetime.now() + except psycopg.DatabaseError as e: + end_time = datetime.now() + self.handle_database_error_during_execute(connection, (start_time, end_time), batch_events) + # Exit + raise e + + curr_resultset = True + while curr_resultset and len(self.batches) <= constants.MAX_BATCH_RESULT_MESSAGES: + # Break if canceled + if self.is_canceled: + break + + # Create and append a new batch object + batch_obj = self.create_next_batch(self.current_batch_index, (start_time, end_time), batch_events) + + # Create the result set if necessary and set to _has_executed + batch_obj.after_execute(cur) + batch_obj._has_executed = True + + if cur and cur.statusmessage is not None: + batch_obj.status_message = cur.statusmessage + + # Update while loop values + curr_resultset = cur.nextset() + self._current_batch_index += 1 + + # Call Completed callback + if batch_events and batch_events._on_execution_completed: + if not curr_resultset or len(self.batches) >= constants.MAX_BATCH_RESULT_MESSAGES: + batch_obj._notices = self._notices + batch_obj.notices.append(f"WARNING: This query has reached the max limit of {constants.MAX_BATCH_RESULT_MESSAGES} results. The rest of the query has been executed, but further results will not be shown") + batch_events._on_execution_completed(batch_obj) + break + else: + batch_events._on_execution_completed(batch_obj) - batch.execute(connection) finally: # We can only set autocommit when the connection is open. 
if connection.open and connection.transaction_is_idle: @@ -155,6 +193,42 @@ def execute(self, connection: ServerConnection, retry_state=False): self._disable_auto_commit = False self._execution_state = ExecutionState.EXECUTED + def create_next_batch(self, ordinal: int, execution_times: Tuple[datetime, datetime], batch_events: BatchEvents, empty_selection_data = False): + start_time, end_time = execution_times + batch_obj = create_batch( + self.query_text, + self.current_batch_index, + self.selection_data, + self.query_events.batch_events, + self.query_execution_settings.result_set_storage_type + ) + self.batches.append(batch_obj) + + # Only set end execution time to first batch summary as we cannot collect individual statement execution times + batch_obj._execution_start_time = start_time + if self.current_batch_index == 0: + batch_obj._execution_end_time = end_time + + # Call start callback + if batch_events and batch_events._on_execution_started: + batch_events._on_execution_started(batch_obj) + return batch_obj + + def handle_database_error_during_execute(self, conn: ServerConnection, execution_times: Tuple[datetime, datetime], batch_events: BatchEvents): + batch_obj = self.create_next_batch(0, execution_times, batch_events) + + batch_obj._has_error = True + self.batches.append(batch_obj) + self._current_batch_index = 0 + conn.set_transaction_in_error() + + def notice_handler(self, notice: psycopg.errors.Diagnostic, conn: ServerConnection): + # Add notices to last batch element + if not conn.user_transaction: + self._notices.append('{0}: {1}'.format(notice.severity, notice.message_primary)) + elif not notice.message_primary == 'there is already a transaction in progress': + self._notices.append('WARNING: {0}'.format(notice.message_primary)) + def get_subset(self, batch_index: int, start_index: int, end_index: int): if batch_index < 0 or batch_index >= len(self._batches): raise IndexError('Batch index cannot be less than 0 or greater than the number of batches') 
@@ -180,16 +254,26 @@ def compute_selection_data_for_batches(batches: List[str], full_text: str) -> Li # Iterate through the batches to build selection data selection_data: List[SelectionData] = [] search_offset = 0 + line_map_keys = sorted(line_map.keys()) + l, r = 0, 0 + start_line_index, end_line_index = 0, 0 for batch in batches: # Calculate the starting line number and column - start_index = full_text.index(batch, search_offset) - start_line_index = max(filter(lambda line_index: line_index <= start_index, line_map.keys())) - start_line_num = line_map[start_line_index] - start_col_num = start_index - start_line_index + start_index = full_text.index(batch, search_offset) # batch start index + # start_line_index = max(filter(lambda line_index: line_index <= start_index, line_map_keys)) # find the character index of the batch start line + while l < len(line_map_keys) and line_map_keys[l] <= start_index: + start_line_index = line_map_keys[l] + l += 1 + + start_line_num = line_map[start_line_index] # map that to the line number + start_col_num = start_index - start_line_index # Calculate the ending line number and column end_index = start_index + len(batch) - end_line_index = max(filter(lambda line_index: line_index < end_index, line_map.keys())) + # end_line_index = max(filter(lambda line_index: line_index < end_index, line_map_keys)) + while r < len(line_map_keys) and line_map_keys[r] < end_index: + end_line_index = line_map_keys[r] + r += 1 end_line_num = line_map[end_line_index] end_col_num = end_index - end_line_index diff --git a/ossdbtoolsservice/query_execution/query_execution_service.py b/ossdbtoolsservice/query_execution/query_execution_service.py index 019c586f5..6fa353847 100644 --- a/ossdbtoolsservice/query_execution/query_execution_service.py +++ b/ossdbtoolsservice/query_execution/query_execution_service.py @@ -9,6 +9,7 @@ from typing import Callable, Dict, List # noqa import sqlparse import ntpath +from utils import constants from 
ossdbtoolsservice.hosting import RequestContext, ServiceProvider @@ -238,12 +239,6 @@ def _batch_execution_started_callback(batch: Batch) -> None: def _batch_execution_finished_callback(batch: Batch) -> None: # Send back notices as a separate message to avoid error coloring / highlighting of text - notices = batch.notices - if notices: - notice_messages = '\n'.join(notices) - notice_message_params = self.build_message_params(worker_args.owner_uri, batch.id, notice_messages, False) - _check_and_fire(worker_args.on_message_notification, notice_message_params) - batch_summary = batch.batch_summary # send query/resultSetComplete response @@ -255,6 +250,12 @@ def _batch_execution_finished_callback(batch: Batch) -> None: rows_message = _create_rows_affected_message(batch) message_params = self.build_message_params(worker_args.owner_uri, batch.id, rows_message, False) _check_and_fire(worker_args.on_message_notification, message_params) + + notices = batch.notices + if notices: + notice_messages = '\n'.join(notices) + notice_message_params = self.build_message_params(worker_args.owner_uri, batch.id, notice_messages, False) + _check_and_fire(worker_args.on_message_notification, notice_message_params) # send query/batchComplete and query/complete response batch_event_params = BatchNotificationParams(batch_summary, worker_args.owner_uri) @@ -368,7 +369,7 @@ def _execute_query_request_worker(self, worker_args: ExecuteRequestWorkerArgs, r self._resolve_query_exception(e, query, worker_args) finally: # Send a query complete notification - batch_summaries = [batch.batch_summary for batch in query.batches] + batch_summaries = [batch.batch_summary for batch in query.batches[:constants.MAX_BATCH_RESULT_MESSAGES]] query_complete_params = QueryCompleteNotificationParams(worker_args.owner_uri, batch_summaries) _check_and_fire(worker_args.on_query_complete, query_complete_params) diff --git a/ossdbtoolsservice/utils/constants.py b/ossdbtoolsservice/utils/constants.py index 
054fd0352..8a16ed229 100644 --- a/ossdbtoolsservice/utils/constants.py +++ b/ossdbtoolsservice/utils/constants.py @@ -14,6 +14,10 @@ COSMOS_PG_DEFAULT_DB = "COSMOSPGSQL" PG_DEFAULT_DB = PG_PROVIDER_NAME +# Max number of results that will be shown from a query's batch results. Will show only first 100 result messages. +# TODO: Show LAST 100 result messages +MAX_BATCH_RESULT_MESSAGES = 100 + DEFAULT_DB = { PG_DEFAULT_DB: "postgres", COSMOS_PG_DEFAULT_DB: "citus"