Skip to content
16 changes: 16 additions & 0 deletions ossdbtoolsservice/driver/types/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def default_database(self) -> str:
def database_error(self) -> Exception:
""" Returns the type of database error this connection throws"""

@property
@abstractmethod
def transaction_is_active(self) -> bool:
    """Returns bool indicating if a transaction is currently active (a statement is executing)"""

@property
@abstractmethod
def transaction_in_error(self) -> bool:
Expand All @@ -71,6 +76,11 @@ def transaction_in_error(self) -> bool:
def transaction_is_idle(self) -> bool:
"""Returns bool indicating if transaction is currently idle"""

@property
@abstractmethod
def transaction_in_unknown(self) -> bool:
    """Returns bool indicating if the transaction state is unknown (copy-paste fix: previously said 'active')"""

@property
@abstractmethod
def transaction_in_trans(self) -> bool:
Expand Down Expand Up @@ -179,6 +189,12 @@ def close(self):
Closes this current connection.
"""


@abstractmethod
def transaction_status(self):
    """
    Gets the current transaction status of the underlying connection if it
    exists, otherwise None.

    Declared abstract for consistency with the sibling abstract members of
    this interface (e.g. set_transaction_in_error) so concrete drivers are
    required to implement it.
    """

@abstractmethod
def set_transaction_in_error(self):
"""
Expand Down
18 changes: 18 additions & 0 deletions ossdbtoolsservice/driver/types/psycopg_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ def database_error(self):
"""Returns the type of database error this connection throws"""
return self._database_error

@property
def transaction_is_active(self) -> bool:
    """True when the server reports an ACTIVE transaction status (a command is running)."""
    current_status = self._conn.info.transaction_status
    return current_status is TransactionStatus.ACTIVE

@property
def transaction_in_error(self) -> bool:
"""Returns bool indicating if transaction is in error"""
Expand All @@ -159,6 +164,11 @@ def transaction_in_error(self) -> bool:
def transaction_is_idle(self) -> bool:
"""Returns bool indicating if transaction is currently idle"""
return self._conn.info.transaction_status is TransactionStatus.IDLE

@property
def transaction_in_unknown(self) -> bool:
    """True when the server reports an UNKNOWN transaction status (e.g. bad connection)."""
    current_status = self._conn.info.transaction_status
    return current_status is TransactionStatus.UNKNOWN

@property
def transaction_in_trans(self) -> bool:
Expand Down Expand Up @@ -315,6 +325,14 @@ def close(self):
"""
self._conn.close()

def transaction_status(self):
    """
    Gets the current transaction status if it exists

    Returns the driver-level transaction status, or None when there is no
    usable connection / connection info to read it from.
    """
    connection = self._conn
    # Preserve truthiness semantics of the original guard (not an `is None` check)
    if not (connection and connection.info):
        return None
    return connection.info.transaction_status

def set_transaction_in_error(self):
"""
Sets if current connection is in error
Expand Down
63 changes: 10 additions & 53 deletions ossdbtoolsservice/query/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,50 +107,6 @@ def notices(self) -> List[str]:
def get_cursor(self, connection: ServerConnection):
return connection.cursor()

def execute(self, conn: ServerConnection) -> None:
    """
    Execute the batch using a cursor retrieved from the given connection

    :raises DatabaseError: if an error is encountered while running the batch's query
    """
    self._execution_start_time = datetime.now()

    # Fire the "execution started" callback if one was registered
    if self._batch_events and self._batch_events._on_execution_started:
        self._batch_events._on_execution_started(self)

    cursor = self.get_cursor(conn)

    # Route server notices (NOTICE/WARNING messages) into this batch's handler
    conn.connection.add_notice_handler(lambda msg: self.notice_handler(msg, conn))

    # Warn when BEGIN is issued while a transaction is already open
    # NOTE(review): assumes batch_text has been lower-cased/normalized upstream -- confirm
    if self.batch_text.startswith('begin') and conn.transaction_in_trans:
        self._notices.append('WARNING: there is already a transaction in progress')

    try:
        cursor.execute(self.batch_text)

        # Commit the transaction if autocommit is True
        if conn.autocommit:
            conn.commit()

        self.after_execute(cursor)
    except psycopg.DatabaseError as e:
        # Mark both the batch and the connection as being in an error state,
        # then propagate the original exception to the caller
        self._has_error = True
        conn.set_transaction_in_error()
        raise e
    finally:
        if cursor and cursor.statusmessage is not None:
            self.status_message = cursor.statusmessage
        # We are doing this because when the execute fails for named cursors
        # cursor is not activated on the server which results in failure on close
        # Hence we are checking if the cursor was really executed for us to close it
        if cursor and cursor.rowcount != -1 and cursor.rowcount is not None:
            cursor.close()
        # The batch counts as executed (and gets an end time) even on failure
        self._has_executed = True
        self._execution_end_time = datetime.now()

    # Completion callback only fires when execution did not raise
    if self._batch_events and self._batch_events._on_execution_completed:
        self._batch_events._on_execution_completed(self)

def after_execute(self, cursor) -> None:
if cursor.description is not None:
self.create_result_set(cursor)
Expand Down Expand Up @@ -201,14 +157,15 @@ def create_result_set(storage_type: ResultSetStorageType, result_set_id: int, ba
return InMemoryResultSet(result_set_id, batch_id)


def create_batch(batch_text: str, ordinal: int, selection: SelectionData, batch_events: BatchEvents, storage_type: ResultSetStorageType) -> Batch:
sql = sqlparse.parse(batch_text)
statement = sql[0]

if statement.get_type().lower() == 'select':
into_checker = [True for token in statement.tokens if token.normalized == 'INTO']
cte_checker = [True for token in statement.tokens if token.ttype == sqlparse.tokens.Keyword.CTE]
if len(into_checker) == 0 and len(cte_checker) == 0: # SELECT INTO and CTE keywords can't be used in named cursor
return SelectBatch(batch_text, ordinal, selection, batch_events, storage_type)
def create_batch(batch_text: str, ordinal: int, selection: SelectionData, batch_events: BatchEvents, storage_type: ResultSetStorageType, select_batch: bool = False) -> Batch:
    """
    Create a batch object for the given SQL text.

    The caller now decides (via select_batch) whether a server-side named
    cursor can be used, replacing the previous sqlparse-based SELECT/INTO/CTE
    detection, whose dead commented-out remnants have been removed.

    :param batch_text: SQL text the batch will execute
    :param ordinal: zero-based position of this batch within its query
    :param selection: document selection the batch text came from
    :param batch_events: callbacks fired on batch start/completion
    :param storage_type: where result sets are stored (file-backed or in-memory)
    :param select_batch: when True, build a SelectBatch (named cursor); otherwise a plain Batch
    :return: the constructed Batch (or SelectBatch) instance
    """
    if select_batch:
        return SelectBatch(batch_text, ordinal, selection, batch_events, storage_type)
    return Batch(batch_text, ordinal, selection, batch_events, storage_type)
172 changes: 128 additions & 44 deletions ossdbtoolsservice/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from datetime import datetime
from enum import Enum
from typing import Callable, Dict, List, Optional # noqa
from typing import Callable, Dict, List, Optional, Tuple # noqa

import sqlparse
from ossdbtoolsservice.driver import ServerConnection
from ossdbtoolsservice.query import Batch, BatchEvents, create_batch, ResultSetStorageType
from ossdbtoolsservice.query.contracts import SaveResultsRequestParams, SelectionData
from ossdbtoolsservice.query.data_storage import FileStreamFactory
import psycopg
from utils import constants


class QueryEvents:
Expand Down Expand Up @@ -59,43 +62,27 @@ def __init__(self, owner_uri: str, query_text: str, query_execution_settings: Qu
self._user_transaction = False
self._current_batch_index = 0
self._batches: List[Batch] = []
self._notices: List[str] = []
self._execution_plan_options = query_execution_settings.execution_plan_options
self._query_events = query_events
self._query_execution_settings = query_execution_settings

self.is_canceled = False

# Initialize the batches
statements = sqlparse.split(query_text)
selection_data = compute_selection_data_for_batches(statements, query_text)

for index, batch_text in enumerate(statements):
# Skip any empty text
formatted_text = sqlparse.format(batch_text, strip_comments=True).strip()
if not formatted_text or formatted_text == ';':
continue

sql_statement_text = batch_text

# Create and save the batch
if bool(self._execution_plan_options):
if self._execution_plan_options.include_estimated_execution_plan_xml:
sql_statement_text = Query.EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text)
elif self._execution_plan_options.include_actual_execution_plan_xml:
self._disable_auto_commit = True
sql_statement_text = Query.ANALYZE_EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text)

# Check if user defined transaction
if formatted_text.lower().startswith('begin'):
# Use the same selection data for all batches. We want to avoid parsing and splitting into separate SQL statements
self.selection_data = compute_selection_data_for_batches([self.query_text], self.query_text)[0]

# # Create and save the batch
if bool(self._execution_plan_options):
if self._execution_plan_options.include_estimated_execution_plan_xml:
sql_statement_text = Query.EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text)
elif self._execution_plan_options.include_actual_execution_plan_xml:
self._disable_auto_commit = True
self._user_transaction = True

batch = create_batch(
sql_statement_text,
len(self.batches),
selection_data[index],
query_events.batch_events,
query_execution_settings.result_set_storage_type)
sql_statement_text = Query.ANALYZE_EXPLAIN_QUERY_TEMPLATE.format(sql_statement_text)

self._batches.append(batch)
# Check if user defined transaction
if self.query_text.lower().startswith('begin'):
self._disable_auto_commit = True
self._user_transaction = True

@property
def owner_uri(self) -> str:
Expand All @@ -116,6 +103,14 @@ def batches(self) -> List[Batch]:
@property
def current_batch_index(self) -> int:
    """Index of the batch currently being executed / most recently created."""
    return self._current_batch_index

@property
def query_events(self) -> QueryEvents:
    """Callbacks registered for this query's lifecycle (including batch events)."""
    return self._query_events

@property
def query_execution_settings(self) -> QueryExecutionSettings:
    """Settings this query was created with (result-set storage type, execution plan options, ...)."""
    return self._query_execution_settings

def execute(self, connection: ServerConnection, retry_state=False):
"""
Expand All @@ -140,13 +135,56 @@ def execute(self, connection: ServerConnection, retry_state=False):
if self._disable_auto_commit and connection.transaction_is_idle:
connection.autocommit = False

for batch_index, batch in enumerate(self._batches):
self._current_batch_index = batch_index

if self.is_canceled:
break
# Start a cursor block
batch_events: BatchEvents = None
if self.query_events is not None and self.query_events.batch_events is not None:
batch_events = self.query_events.batch_events

connection.connection.add_notice_handler(lambda msg: self.notice_handler(msg, connection))
with connection.cursor() as cur:
start_time = datetime.now()

try:
if self.is_canceled:
return
cur.execute(self.query_text)
end_time = datetime.now()
except psycopg.DatabaseError as e:
end_time = datetime.now()
self.handle_database_error_during_execute(connection, (start_time, end_time), batch_events)
# Exit
raise e

curr_resultset = True
while curr_resultset and len(self.batches) <= constants.MAX_BATCH_RESULT_MESSAGES:
# Break if canceled
if self.is_canceled:
break

# Create and append a new batch object
batch_obj = self.create_next_batch(self.current_batch_index, (start_time, end_time), batch_events)

# Create the result set if necessary and set to _has_executed
batch_obj.after_execute(cur)
batch_obj._has_executed = True

if cur and cur.statusmessage is not None:
batch_obj.status_message = cur.statusmessage

# Update while loop values
curr_resultset = cur.nextset()
self._current_batch_index += 1

# Call Completed callback
if batch_events and batch_events._on_execution_completed:
if not curr_resultset or len(self.batches) >= constants.MAX_BATCH_RESULT_MESSAGES:
batch_obj._notices = self._notices
batch_obj.notices.append(f"WARNING: This query has reached the max limit of {constants.MAX_BATCH_RESULT_MESSAGES} results. The rest of the query has been executed, but furthter results will not be shown")
batch_events._on_execution_completed(batch_obj)
break
else:
batch_events._on_execution_completed(batch_obj)

batch.execute(connection)
finally:
# We can only set autocommit when the connection is open.
if connection.open and connection.transaction_is_idle:
Expand All @@ -155,6 +193,42 @@ def execute(self, connection: ServerConnection, retry_state=False):
self._disable_auto_commit = False
self._execution_state = ExecutionState.EXECUTED

def create_next_batch(self, ordinal: int, execution_times: Tuple[datetime, datetime], batch_events: BatchEvents, empty_selection_data = False):
    """
    Create, register and return the next Batch object for this query.

    The new batch is appended to self.batches and its start callback is fired.

    :param ordinal: intended ordinal of the new batch.
        NOTE(review): currently unused -- the batch is created with
        self.current_batch_index instead; confirm whether callers
        (e.g. handle_database_error_during_execute passes 0) rely on it.
    :param execution_times: (start_time, end_time) captured around the overall execute
    :param batch_events: callbacks; only _on_execution_started is fired here.
        NOTE(review): the batch itself is built with
        self.query_events.batch_events, not this parameter -- the two may differ.
    :param empty_selection_data: NOTE(review): unused; confirm before removing.
    :return: the newly created batch object
    """
    start_time, end_time = execution_times
    batch_obj = create_batch(
        self.query_text,
        self.current_batch_index,
        self.selection_data,
        self.query_events.batch_events,
        self.query_execution_settings.result_set_storage_type
    )
    self.batches.append(batch_obj)

    # Only set end execution time to first batch summary as we cannot collect individual statement execution times
    batch_obj._execution_start_time = start_time
    if self.current_batch_index == 0:
        batch_obj._execution_end_time = end_time

    # Call start callback
    if batch_events and batch_events._on_execution_started:
        batch_events._on_execution_started(batch_obj)
    return batch_obj

def handle_database_error_during_execute(self, conn: ServerConnection, execution_times: Tuple[datetime, datetime], batch_events: BatchEvents):
    """
    Record a DatabaseError raised while executing the query.

    Creates a single error batch (ordinal 0) so the client still receives a
    batch summary, flags it as failed, and marks the connection's transaction
    as being in error.

    :param conn: connection the query was executing on
    :param execution_times: (start_time, end_time) captured around the failed execute
    :param batch_events: callbacks forwarded to create_next_batch
    """
    batch_obj = self.create_next_batch(0, execution_times, batch_events)

    batch_obj._has_error = True
    # Bug fix: create_next_batch already appends the batch to self.batches;
    # the previous extra self.batches.append(batch_obj) produced a duplicate entry.
    self._current_batch_index = 0
    conn.set_transaction_in_error()

def notice_handler(self, notice: psycopg.errors.Diagnostic, conn: ServerConnection):
    """Collect a server notice into this query's notice list.

    Outside a user transaction the full "<severity>: <message>" text is kept.
    Inside one, the message is recorded as a warning instead, except for the
    redundant "transaction already in progress" notice, which is dropped.
    """
    message = notice.message_primary
    if not conn.user_transaction:
        self._notices.append('{0}: {1}'.format(notice.severity, message))
        return
    if message == 'there is already a transaction in progress':
        return
    self._notices.append('WARNING: {0}'.format(message))

def get_subset(self, batch_index: int, start_index: int, end_index: int):
if batch_index < 0 or batch_index >= len(self._batches):
raise IndexError('Batch index cannot be less than 0 or greater than the number of batches')
Expand All @@ -180,16 +254,26 @@ def compute_selection_data_for_batches(batches: List[str], full_text: str) -> Li
# Iterate through the batches to build selection data
selection_data: List[SelectionData] = []
search_offset = 0
line_map_keys = sorted(line_map.keys())
l, r = 0, 0
start_line_index, end_line_index = 0, 0
for batch in batches:
# Calculate the starting line number and column
start_index = full_text.index(batch, search_offset)
start_line_index = max(filter(lambda line_index: line_index <= start_index, line_map.keys()))
start_line_num = line_map[start_line_index]
start_col_num = start_index - start_line_index
start_index = full_text.index(batch, search_offset) # batch start index
# start_line_index = max(filter(lambda line_index: line_index <= start_index, line_map_keys)) # find the character index of the batch start line
while l < len(line_map_keys) and line_map_keys[l] <= start_index:
start_line_index = line_map_keys[l]
l += 1

start_line_num = line_map[start_line_index] # map that to the line number
start_col_num = start_index - start_line_index

# Calculate the ending line number and column
end_index = start_index + len(batch)
end_line_index = max(filter(lambda line_index: line_index < end_index, line_map.keys()))
# end_line_index = max(filter(lambda line_index: line_index < end_index, line_map_keys))
while r < len(line_map_keys) and line_map_keys[r] < end_index:
end_line_index = line_map_keys[r]
r += 1
end_line_num = line_map[end_line_index]
end_col_num = end_index - end_line_index

Expand Down
Loading