diff --git a/snowflake_utils/models/table.py b/snowflake_utils/models/table.py index 852c2ca..e3775fd 100644 --- a/snowflake_utils/models/table.py +++ b/snowflake_utils/models/table.py @@ -99,10 +99,12 @@ def get_create_temporary_external_stage( def get_create_table_statement( self, full_refresh: bool = False, + copy_grants: bool = True, ) -> str: logging.debug(f"Creating table: {self.fqn}") + copy_grants_clause = " COPY GRANTS" if copy_grants and full_refresh else "" if self.table_structure: - return f"{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.fqn} ({self.table_structure.parsed_columns})" + return f"{'CREATE OR REPLACE TABLE' if full_refresh else 'CREATE TABLE IF NOT EXISTS'} {self.fqn}{copy_grants_clause} ({self.table_structure.parsed_columns})" else: template = """ARRAY_AGG( OBJECT_CONSTRUCT( @@ -129,7 +131,7 @@ def get_create_table_statement( stage_query = f"LOCATION => '@{self.stage}'" return f""" - {"CREATE OR REPLACE TABLE" if full_refresh else "CREATE TABLE IF NOT EXISTS"} {self.fqn} + {"CREATE OR REPLACE TABLE" if full_refresh else "CREATE TABLE IF NOT EXISTS"} {self.fqn}{copy_grants_clause} USING TEMPLATE ( SELECT {template} FROM TABLE( @@ -151,7 +153,9 @@ def bulk_insert( cursor = connection.cursor() _execute_statement = partial(execute_statement, cursor) _execute_statement(self.get_create_schema_statement()) - _execute_statement(self.get_create_table_statement(full_refresh)) + _execute_statement( + self.get_create_table_statement(full_refresh, copy_grants=True) + ) for k in records: cols = ", ".join([k for k in records[k].keys()]) vals = ", ".join([_type_cast(v) for v in records[k].values()]) @@ -173,6 +177,7 @@ def _copy( sync_tags: bool = False, stage: str | None = None, create_table: bool = True, + copy_grants: bool = True, ) -> None: with connect() as connection: cursor = connection.cursor() @@ -180,7 +185,7 @@ def _copy( path, storage_integration, cursor, file_format, stage ) if create_table: - self.create_table(full_refresh, execute) + self.create_table(full_refresh, execute, copy_grants) if sync_tags and self.table_structure: self.sync_tags(cursor) @@ -202,6 +207,7 @@ def copy_into( stage: str | None = None, files: list[str] | None = None, create_table: bool = True, + copy_grants: bool = True, ) -> None: col_str = f"({', '.join(target_columns)})" if target_columns else "" files_clause = "" @@ -229,6 +235,7 @@ def copy_into( sync_tags, stage, create_table, + copy_grants, ) with connect() as connection: cursor = connection.cursor() @@ -249,10 +256,13 @@ def copy_into( sync_tags, stage, create_table, + copy_grants, ) - def create_table(self, full_refresh: bool, execute_statement: callable) -> None: - execute_statement(self.get_create_table_statement(full_refresh)) + def create_table( + self, full_refresh: bool, execute_statement: callable, copy_grants: bool = True + ) -> None: + execute_statement(self.get_create_table_statement(full_refresh, copy_grants)) def setup_file_format( self, @@ -311,7 +321,9 @@ def _merge( with connect() as connection: cursor = connection.cursor() - cursor.execute(self.get_create_table_statement(full_refresh=False)) + cursor.execute( + self.get_create_table_statement(full_refresh=False, copy_grants=True) + ) old_columns = {x.name: x.data_type for x in self.get_columns(cursor)} new_columns = temp_table.get_columns(cursor) @@ -338,6 +350,7 @@ def merge( match_by_column_name: MatchByColumnName = MatchByColumnName.CASE_INSENSITIVE, qualify: bool = False, files: list[str] | None = None, + copy_grants: bool = True, ) -> None: def copy_callable(table: Table, sync_tags: bool) -> None: return table.copy_into( @@ -347,6 +360,7 @@ def copy_callable(table: Table, sync_tags: bool) -> None: match_by_column_name=match_by_column_name, sync_tags=sync_tags, files=files, + copy_grants=copy_grants, ) return self._merge(copy_callable, primary_keys, replication_keys, qualify) @@ -567,6 +581,7 @@ def copy_custom( stage: str | None = None, files: list[str] | None = None, create_table: bool = True, + copy_grants: bool = True, ) -> None: column_names = ", ".join(column_definitions.keys()) definitions = ", ".join(column_definitions.values()) @@ -593,6 +608,7 @@ def copy_custom( sync_tags, stage, create_table, + copy_grants, ) def merge_custom( @@ -606,6 +622,7 @@ def merge_custom( qualify: bool = False, files: list[str] | None = None, create_table: bool = True, + copy_grants: bool = True, ) -> None: def copy_callable(table: Table, sync_tags: bool) -> None: return table.copy_custom( @@ -617,6 +634,7 @@ def copy_callable(table: Table, sync_tags: bool) -> None: sync_tags=sync_tags, files=files, create_table=create_table, + copy_grants=copy_grants, ) return self._merge(copy_callable, primary_keys, replication_keys, qualify) diff --git a/tests/test_models.py b/tests/test_models.py index 3dfbfd6..0737750 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -97,6 +97,40 @@ def test_create_or_replace_table(mock_connect): assert result == f"Table {test_table.name} successfully created." +@patch("snowflake_utils.settings.connect") +def test_create_or_replace_table_with_copy_grants(mock_connect): + """Test that COPY GRANTS clause is included when copy_grants=True and full_refresh=True.""" + mock_cursor = make_mock_cursor( + fetchall_return=[(f"Table {test_table.name} successfully created.",)] + ) + mock_conn = make_mock_conn(cursor=mock_cursor) + mock_connect.return_value = mock_conn + statement = test_table.get_create_table_statement( + full_refresh=True, copy_grants=True + ) + result = mock_cursor.execute(statement).fetchall()[0][0] + assert result == f"Table {test_table.name} successfully created." + # Verify that COPY GRANTS is included in the statement + assert "COPY GRANTS" in statement + + +@patch("snowflake_utils.settings.connect") +def test_create_or_replace_table_without_copy_grants(mock_connect): + """Test that COPY GRANTS clause is not included when copy_grants=False and full_refresh=True.""" + mock_cursor = make_mock_cursor( + fetchall_return=[(f"Table {test_table.name} successfully created.",)] + ) + mock_conn = make_mock_conn(cursor=mock_cursor) + mock_connect.return_value = mock_conn + statement = test_table.get_create_table_statement( + full_refresh=True, copy_grants=False + ) + result = mock_cursor.execute(statement).fetchall()[0][0] + assert result == f"Table {test_table.name} successfully created." + # Verify that COPY GRANTS is not included in the statement + assert "COPY GRANTS" not in statement + + @patch("snowflake_utils.settings.connect") def test_create_table_if_not_exists(mock_connect): mock_cursor = make_mock_cursor(fetchall_return=[("statement succeeded: PYTEST",)]) @@ -109,6 +143,23 @@ def test_create_table_if_not_exists(mock_connect): ) +@patch("snowflake_utils.settings.connect") +def test_create_table_if_not_exists_copy_grants_ignored(mock_connect): + """Test that COPY GRANTS clause is not included when full_refresh=False, even if copy_grants=True.""" + mock_cursor = make_mock_cursor(fetchall_return=[("statement succeeded: PYTEST",)]) + mock_conn = make_mock_conn(cursor=mock_cursor) + mock_connect.return_value = mock_conn + statement = test_table.get_create_table_statement( + full_refresh=False, copy_grants=True + ) + result = mock_cursor.execute(statement).fetchall()[0][0] + assert ("statement succeeded" in result and test_table.name in result) or ( + f"Table {test_table.name} successfully created." + ) + # Verify that COPY GRANTS is not included in the statement (only applies to CREATE OR REPLACE) + assert "COPY GRANTS" not in statement + + @patch("snowflake_utils.settings.connect") def test_temporary_external_stage_creation(mock_connect): mock_cursor = make_mock_cursor(