From 1adc9ceb10e4f01b25961f9daad218c46441275d Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Mon, 26 Jan 2026 15:07:24 +0000 Subject: [PATCH 01/10] DataNodeExtractTrans implementation --- .../psyir/transformations/__init__.py | 4 + .../transformations/datanode_extract_trans.py | 157 ++++++++++++++++++ .../datanode_extract_trans_test.py | 104 ++++++++++++ 3 files changed, 265 insertions(+) create mode 100644 src/psyclone/psyir/transformations/datanode_extract_trans.py create mode 100644 src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py diff --git a/src/psyclone/psyir/transformations/__init__.py b/src/psyclone/psyir/transformations/__init__.py index 9f186ee31e..4c821096c1 100644 --- a/src/psyclone/psyir/transformations/__init__.py +++ b/src/psyclone/psyir/transformations/__init__.py @@ -129,6 +129,9 @@ from psyclone.psyir.transformations.omp_parallel_trans import ( OMPParallelTrans, ) +from psyclone.psyir.transformations.datanode_extract_trans import ( + DataNodeExtractTrans +) # For AutoAPI documentation generation __all__ = [ @@ -181,4 +184,5 @@ "OMPDeclareTargetTrans", "OMPCriticalTrans", "OMPParallelTrans", + "DataNodeExtractTrans", ] diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_extract_trans.py new file mode 100644 index 0000000000..c7346c6a0e --- /dev/null +++ b/src/psyclone/psyir/transformations/datanode_extract_trans.py @@ -0,0 +1,157 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Authors A. B. G. Chalk STFC Daresbury Lab + +'''This module contains the DataNodeExtractTrans class.''' + +from psyclone.psyGen import Transformation +from psyclone.psyir.transformations import TransformationError +from psyclone.psyir.nodes import DataNode, Reference, Assignment, Statement +from psyclone.psyir.symbols.datatypes import UnresolvedType, ArrayType +from psyclone.psyir.symbols import DataSymbol +from psyclone.utils import transformation_documentation_wrapper + + +@transformation_documentation_wrapper +class DataNodeExtractTrans(Transformation): + """Provides a generic transformation for moving a datanode from a + statement into a new standalone statement. For example: + + >>> from psyclone.psyir.frontend.fortran import FortranReader + >>> from psyclone.psyir.backend.fortran import FortranWriter + >>> from psyclone.psyir.nodes import Assignment + >>> from psyclone.psyir.transformations import DataNodeExtractTrans + >>> + >>> psyir = FortranReader().psyir_from_source(''' + ... subroutine my_subroutine() + ... integer :: i + ... integer :: j + ... i = j * 2 + ... end subroutine + ... ''') + >>> assign = psyir.walk(Assignment)[0] + >>> DataNodeExtractTrans().apply(assign.rhs, storage_name="temp") + >>> print(FortranWriter()(psyir)) + subroutine my_subroutine() + integer, dimension(10,10) :: a + integer :: i + integer :: j + integer :: temp + + temp = j * 2 + i = temp + + end subroutine my_subroutine + + """ + + def validate(self, node: DataNode, **kwargs): + """Validity checks for input arguments + + :param node: The DataNode to be extracted. + + :raises TypeError: if the input arguments are the wrong types. + :raises TransformationError: if the input node's datatype can't be + resolved. + :raises TransformationError: if the input node's datatype is an array + but any of the array's dimensions are + unknown. + """ + # Validate the input options and types. + self.validate_options(**kwargs) + + if isinstance(node.datatype, UnresolvedType): + raise TransformationError( + f"Input node's datatype cannot be computed, so the " + f"DataNodeExtractTrans cannot be applied. Input node was " + f"'{node.debug_string().strip()}'." + ) + + dtype = node.datatype + if isinstance(dtype, ArrayType): + for element in dtype.shape: + if element in [ArrayType.Extent.DEFERRED, + ArrayType.Extent.ATTRIBUTE]: + raise TransformationError( + f"Input node's datatype is an array of unknown size, " + f"so the DataNodeExtractTrans cannot be applied. " + f"Input node was '{node.debug_string().strip()}'." + ) + + def apply(self, node: DataNode, storage_name: str = "", **kwargs): + """Applies the DataNodeExtractTransApplies to the input arguments. + + :param node: The datanode to extract. + :param storage_name: The name of the temporary variable to store + the result of the input node in. The default + is tmp(_...) based on the rules defined + in the SymbolTable class. + """ + # Call validate to check inputs are valid. + self.validate(node, storage_name=storage_name, **kwargs) + + # Find the datatype + datatype = node.datatype + + # Create a symbol of the relevant type. + if not storage_name: + symbol = node.scope.symbol_table.new_symbol( + root_name="tmp", + symbol_type=DataSymbol, + datatype=datatype + ) + else: + symbol = node.scope.symbol_table.new_symbol( + root_name=storage_name, + symbol_type=DataSymbol, + allow_renaming=False, + datatype=datatype + ) + + # Create a Reference to the new symbol + new_ref = Reference(symbol) + + # Find the parent and position of the statement containing the + # DataNode. + parent = node.ancestor(Statement).parent + pos = node.ancestor(Statement).position + + # Replace the datanode with the new reference + node.replace_with(new_ref) + + # Create an assignment to set the value of the new symbol + assign = Assignment.create(new_ref.copy(), node) + + # Add the assignment into the tree. + parent.addchild(assign, pos) diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py new file mode 100644 index 0000000000..4eefc03c01 --- /dev/null +++ b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py @@ -0,0 +1,104 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Authors A. B. G. Chalk STFC Daresbury Lab + +'''This module contains the DataNodeExtractTrans class.''' + +import pytest + +from psyclone.psyir.nodes import Assignment +from psyclone.psyir.transformations import ( + DataNodeExtractTrans, TransformationError +) + + +def test_datanodeextracttrans_validate(fortran_reader): + """Tests the validate function of the DataNodeExtractTrans.""" + dtrans = DataNodeExtractTrans() + code = """subroutine test(a, b, c) + integer, dimension(:,:), intent(inout) :: a, b, c + c = b + a + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + with pytest.raises(TransformationError) as err: + dtrans.validate(assign.rhs) + assert ("Input node's datatype is an array of unknown size, so the " + "DataNodeExtractTrans cannot be applied. Input node was " + "'b + a'" in str(err.value)) + + code = """subroutine test + use some_mod + c = b + a + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + with pytest.raises(TransformationError) as err: + dtrans.validate(assign.rhs) + assert ("Input node's datatype cannot be computed, so the " + "DataNodeExtractTrans cannot be applied. Input node " + "was 'b + a'" in str(err.value)) + + +def test_datanodeextractrans_apply(fortran_reader, fortran_writer): + """Tests the apply function of the DataNodeExtractTrans.""" + dtrans = DataNodeExtractTrans() + code = """subroutine test() + integer, dimension(10,100) :: a + integer, dimension(100,10) :: b + integer, dimension(10,10) :: c, d + d = c + MATMUL(a, b) + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs.operands[1]) + out = fortran_writer(psyir) + assert ("integer, dimension(SIZE(a, dim=1),SIZE(b, dim=2)) :: tmp" + in out) + assert "tmp = MATMUL(a, b)" in out + assert "d = c + tmp" in out + + code = """subroutine test() + real :: a + integer :: b + + b = INT(a) + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs, storage_name="temporary") + out = fortran_writer(psyir) + assert "integer :: temporary" in out + assert "temporary = INT(a)" in out + assert "b = temporary" in out From 8c99aa3d3a5e0fec67d8de63b613c1b81f007505 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 27 Jan 2026 14:51:17 +0000 Subject: [PATCH 02/10] Fixes for review --- .../psyir/transformations/__init__.py | 1 - .../transformations/datanode_extract_trans.py | 57 ++++++++++-- .../datanode_extract_trans_test.py | 93 ++++++++++++++++++- 3 files changed, 138 insertions(+), 13 deletions(-) diff --git a/src/psyclone/psyir/transformations/__init__.py b/src/psyclone/psyir/transformations/__init__.py index 4c821096c1..ebe8208112 100644 --- a/src/psyclone/psyir/transformations/__init__.py +++ b/src/psyclone/psyir/transformations/__init__.py @@ -184,5 +184,4 @@ "OMPDeclareTargetTrans", "OMPCriticalTrans", "OMPParallelTrans", - "DataNodeExtractTrans", ] diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_extract_trans.py index c7346c6a0e..b796989327 100644 --- a/src/psyclone/psyir/transformations/datanode_extract_trans.py +++ b/src/psyclone/psyir/transformations/datanode_extract_trans.py @@ -37,8 +37,18 @@ from psyclone.psyGen import Transformation from psyclone.psyir.transformations import TransformationError -from psyclone.psyir.nodes import DataNode, Reference, Assignment, Statement -from psyclone.psyir.symbols.datatypes import UnresolvedType, ArrayType +from psyclone.psyir.nodes import ( + DataNode, + Reference, + Assignment, + Statement, + Call +) +from psyclone.psyir.symbols.datatypes import ( + ArrayType, + UnresolvedType, + UnsupportedFortranType, +) from psyclone.psyir.symbols import DataSymbol from psyclone.utils import transformation_documentation_wrapper @@ -85,20 +95,39 @@ def validate(self, node: DataNode, **kwargs): :raises TransformationError: if the input node's datatype can't be resolved. :raises TransformationError: if the input node's datatype is an array - but any of the array's dimensions are - unknown. + but any of the array's dimensions are unknown. + :raises TransformationError: if the input node doesn't have an + ancestor statement. + :raises TransformationError: if the input node contains a call + that isn't guaranteed to be pure. """ # Validate the input options and types. self.validate_options(**kwargs) - if isinstance(node.datatype, UnresolvedType): + if not isinstance(node, DataNode): + raise TypeError( + f"Input node to DataNodeExtractTrans should be a " + f"DataNode but got '{type(node).__name__}'." + ) + + dtype = node.datatype + + calls = node.walk(Call) + for call in calls: + if not call.is_pure: + raise TransformationError( + f"Input node to DataNodeExtractTrans contains a call " + f"that is not guaranteed to be pure. Input node is " + f"'{node.debug_string().strip()}'." + ) + + if isinstance(dtype, (UnresolvedType, UnsupportedFortranType)): raise TransformationError( f"Input node's datatype cannot be computed, so the " f"DataNodeExtractTrans cannot be applied. Input node was " f"'{node.debug_string().strip()}'." ) - dtype = node.datatype if isinstance(dtype, ArrayType): for element in dtype.shape: if element in [ArrayType.Extent.DEFERRED, @@ -109,14 +138,19 @@ def validate(self, node: DataNode, **kwargs): f"Input node was '{node.debug_string().strip()}'." ) + if node.ancestor(Statement) is None: + raise TransformationError( + "Input node to DataNodeExtractTrans has no ancestor " + "Statement node which is not supported." + ) + def apply(self, node: DataNode, storage_name: str = "", **kwargs): - """Applies the DataNodeExtractTransApplies to the input arguments. + """Applies the DataNodeExtractTrans to the input arguments. :param node: The datanode to extract. :param storage_name: The name of the temporary variable to store - the result of the input node in. The default - is tmp(_...) based on the rules defined - in the SymbolTable class. + the result of the input node in. The default is tmp(_...) + based on the rules defined in the SymbolTable class. """ # Call validate to check inputs are valid. self.validate(node, storage_name=storage_name, **kwargs) @@ -155,3 +189,6 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): # Add the assignment into the tree. parent.addchild(assign, pos) + + +__all__ = ["DataNodeExtractTrans"] diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py index 4eefc03c01..78c7ec1f3d 100644 --- a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py @@ -37,7 +37,12 @@ import pytest -from psyclone.psyir.nodes import Assignment +from psyclone.psyir.nodes import ( + Assignment, Reference +) +from psyclone.psyir.symbols import ( + DataSymbol, INTEGER_TYPE +) from psyclone.psyir.transformations import ( DataNodeExtractTrans, TransformationError ) @@ -68,7 +73,58 @@ def test_datanodeextracttrans_validate(fortran_reader): dtrans.validate(assign.rhs) assert ("Input node's datatype cannot be computed, so the " "DataNodeExtractTrans cannot be applied. Input node " - "was 'b + a'" in str(err.value)) + "was 'b + a'" in str(err.value)) + + code = """subroutine test + character(len=25) :: a, b + + b = a + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + with pytest.raises(TransformationError) as err: + dtrans.validate(assign.rhs) + assert ("Input node's datatype cannot be computed, so the " + "DataNodeExtractTrans cannot be applied. Input node " + "was 'a'" in str(err.value)) + + with pytest.raises(TypeError) as err: + dtrans.validate("abc") + assert ("Input node to DataNodeExtractTrans should be a " + "DataNode but got 'str'." in str(err.value)) + + with pytest.raises(TypeError) as err: + dtrans.validate(assign.rhs, storage_name=1) + assert ("'DataNodeExtractTrans' received options with the wrong types:\n" + "'storage_name' option expects type 'str' but received '1' of " + "type 'int'.\nPlease see the documentation and check the " + "provided types." in str(err.value)) + + with pytest.raises(TransformationError) as err: + dtrans.validate(Reference(DataSymbol("a", INTEGER_TYPE))) + assert ("Input node to DataNodeExtractTrans has no ancestor Statement " + "node which is not supported." in str(err.value)) + + code = """module some_mod + contains + integer function some_func(a, b) + integer :: a, b + a = a + b + some_func = a + b + end function + subroutine test() + integer :: a, b + + a = a + some_func(a,b) + end subroutine test + end module""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[2] + with pytest.raises(TransformationError) as err: + dtrans.validate(assign.rhs) + assert ("Input node to DataNodeExtractTrans contains a call that is not " + "guaranteed to be pure. Input node is 'a + some_func(a, b)'." + in str(err.value)) def test_datanodeextractrans_apply(fortran_reader, fortran_writer): @@ -102,3 +158,36 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer): assert "integer :: temporary" in out assert "temporary = INT(a)" in out assert "b = temporary" in out + + code = """subroutine test() + real, dimension(100) :: b + integer :: i + + do i = 1, 100 + b(i) = REAL(i) + end do + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs, storage_name="temporary") + out = fortran_writer(psyir) + assert " real :: temporary" in out + assert """ do i = 1, 100, 1 + temporary = REAL(i) + b(i) = temporary + enddo""" in out + + code = """subroutine test() + integer, dimension(2:6) :: a + integer, dimension(1:3) :: b + + a(2:4) = 3 * b + + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs) + out = fortran_writer(psyir) + assert " integer, dimension(3) :: tmp" in out + assert """ tmp = 3 * b + a(:4) = tmp""" in out From dd30fe986331953ef89aae17ab153052f1832e0d Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Wed, 28 Jan 2026 10:37:50 +0000 Subject: [PATCH 03/10] Added test for import resolution --- .../datanode_extract_trans_test.py | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py index 78c7ec1f3d..56bbb2e245 100644 --- a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py @@ -35,8 +35,10 @@ '''This module contains the DataNodeExtractTrans class.''' +import os import pytest +from psyclone.configuration import Config from psyclone.psyir.nodes import ( Assignment, Reference ) @@ -127,7 +129,8 @@ def test_datanodeextracttrans_validate(fortran_reader): in str(err.value)) -def test_datanodeextractrans_apply(fortran_reader, fortran_writer): +def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, + monkeypatch): """Tests the apply function of the DataNodeExtractTrans.""" dtrans = DataNodeExtractTrans() code = """subroutine test() @@ -191,3 +194,49 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer): assert " integer, dimension(3) :: tmp" in out assert """ tmp = 3 * b a(:4) = tmp""" in out + + # Test the imports are handled correctly. + monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)]) + filename = os.path.join(str(tmpdir), "a_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module a_mod + integer :: some_var + end module a_mod + ''') + code = """subroutine test() + use a_mod + integer :: b + b = some_var + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + psyir.children[0].symbol_table.resolve_imports() + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs) + out = fortran_writer(psyir) + assert """ integer :: tmp + + tmp = some_var + b = tmp""" in out + + filename = os.path.join(str(tmpdir), "b_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module b_mod + integer, dimension(25, 50) :: some_var + end module b_mod + ''') + code = """subroutine test() + use b_mod + integer, dimension(25, 50) :: b + b = some_var + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + psyir.children[0].symbol_table.resolve_imports() + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs) + out = fortran_writer(psyir) + assert """ integer, dimension(25,50) :: tmp + + tmp = some_var + b = tmp""" in out From f81a2c3673aaf5882b4069c9e0dbceabd085b048 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Thu, 29 Jan 2026 15:12:16 +0000 Subject: [PATCH 04/10] Some more changes towards review --- .../transformations/datanode_extract_trans.py | 50 ++++++++++++++++--- .../datanode_extract_trans_test.py | 49 +++++++++++++++++- 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_extract_trans.py index b796989327..0a997443ba 100644 --- a/src/psyclone/psyir/transformations/datanode_extract_trans.py +++ b/src/psyclone/psyir/transformations/datanode_extract_trans.py @@ -120,14 +120,6 @@ def validate(self, node: DataNode, **kwargs): f"that is not guaranteed to be pure. Input node is " f"'{node.debug_string().strip()}'." ) - - if isinstance(dtype, (UnresolvedType, UnsupportedFortranType)): - raise TransformationError( - f"Input node's datatype cannot be computed, so the " - f"DataNodeExtractTrans cannot be applied. Input node was " - f"'{node.debug_string().strip()}'." - ) - if isinstance(dtype, ArrayType): for element in dtype.shape: if element in [ArrayType.Extent.DEFERRED, @@ -137,6 +129,22 @@ def validate(self, node: DataNode, **kwargs): f"so the DataNodeExtractTrans cannot be applied. " f"Input node was '{node.debug_string().strip()}'." ) + # Otherwise we have an ArrayBounds + symbols = set() + if isinstance(element.lower, DataNode): + symbols.update(element.lower.get_all_accessed_symbols()) + if isinstance(element.upper, DataNode): + symbols.update(element.upper.get_all_accessed_symbols()) + scope_symbols = node.scope.symbol_table.get_symbols() + for sym in symbols: + scoped_name_sym = scope_symbols.get(sym.name, None) + if scoped_name_sym and not sym is scoped_name_sym: + raise TransformationError( + f"Input node contains an imported symbol whose " + f"name collides with an existing symbol, so the " + f"DataNodeExtractTrans cannot be applied. " + f"Clashing symbol name is '{sym.name}'." + ) if node.ancestor(Statement) is None: raise TransformationError( @@ -144,6 +152,13 @@ def validate(self, node: DataNode, **kwargs): "Statement node which is not supported." ) + if isinstance(dtype, (UnresolvedType, UnsupportedFortranType)): + raise TransformationError( + f"Input node's datatype cannot be computed, so the " + f"DataNodeExtractTrans cannot be applied. Input node was " + f"'{node.debug_string().strip()}'." + ) + def apply(self, node: DataNode, storage_name: str = "", **kwargs): """Applies the DataNodeExtractTrans to the input arguments. @@ -173,6 +188,25 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): datatype=datatype ) + # FIXME Make sure the shape is all in the symbol table. We know that + # all symbols we find can be safely added as otherwise validate will + # fail. + # This is an oversimplification because we could have multiple + # references to the same symbol... + if isinstance(datatype, ArrayType): + for element in dtype.shape: + symbols = set() + if isinstance(element.lower, DataNode): + symbols.update(element.lower.get_all_accessed_symbols()) + if isinstance(element.upper, DataNode): + symbols.update(element.upper.get_all_accessed_symbols()) + scope_symbols = node.scope.symbol_table.get_symbols() + for sym in symbols: + scoped_name_sym = scope_symbols.get(sym.name, None) + if not scoped_name_sym: + sym_copy = symbol.copy() + node.scope.symbol_table.add(sym_copy) + # Create a Reference to the new symbol new_ref = Reference(symbol) diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py index 56bbb2e245..ec06e9acc0 100644 --- a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py @@ -50,7 +50,7 @@ ) -def test_datanodeextracttrans_validate(fortran_reader): +def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): """Tests the validate function of the DataNodeExtractTrans.""" dtrans = DataNodeExtractTrans() code = """subroutine test(a, b, c) @@ -128,6 +128,30 @@ def test_datanodeextracttrans_validate(fortran_reader): "guaranteed to be pure. Input node is 'a + some_func(a, b)'." in str(err.value)) + monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)]) + filename = os.path.join(str(tmpdir), "a_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module a_mod + use some_mod, only: i + integer, dimension(25, i) :: some_var + end module a_mod + ''') + code = """subroutine test() + use a_mod, only: some_var + integer, dimension(25, 50) :: b + integer :: i + b = some_var + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + psyir.children[0].symbol_table.resolve_imports() + assign = psyir.walk(Assignment)[0] + with pytest.raises(TransformationError) as err: + dtrans.validate(assign.rhs) + assert ("Input node contains an imported symbol whose name collides " + "with an existing symbol, so the DataNodeExtractTrans cannot be " + "applied. Clashing symbol name is 'i'." in str(err.value)) + def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, monkeypatch): @@ -240,3 +264,26 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, tmp = some_var b = tmp""" in out + + filename = os.path.join(str(tmpdir), "c_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module c_mod + use some_mod, only: i + integer, dimension(25, i) :: some_var + end module c_mod + ''') + code = """subroutine test() + use c_mod, only: some_var + integer, dimension(25, 50) :: b + b = some_var + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + psyir.children[0].symbol_table.resolve_imports() + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs) + out = fortran_writer(psyir) + print(psyir.walk(Assignment)[0].rhs.symbol.shape[1].upper.symbol.interface) + print(psyir.walk(Assignment)[0].rhs.symbol.shape[1].upper.symbol.is_unknown_interface) + print(out) + assert False From 731bb59bc7f552149616e0ad5de9932cc715302e Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Mon, 2 Feb 2026 14:07:32 +0000 Subject: [PATCH 05/10] Interface handling --- .../transformations/datanode_extract_trans.py | 41 ++++++++++++++++--- .../datanode_extract_trans_test.py | 4 -- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_extract_trans.py index 0a997443ba..3c239d6b65 100644 --- a/src/psyclone/psyir/transformations/datanode_extract_trans.py +++ b/src/psyclone/psyir/transformations/datanode_extract_trans.py @@ -49,7 +49,7 @@ UnresolvedType, UnsupportedFortranType, ) -from psyclone.psyir.symbols import DataSymbol +from psyclone.psyir.symbols import DataSymbol, ImportInterface from psyclone.utils import transformation_documentation_wrapper @@ -138,7 +138,19 @@ def validate(self, node: DataNode, **kwargs): scope_symbols = node.scope.symbol_table.get_symbols() for sym in symbols: scoped_name_sym = scope_symbols.get(sym.name, None) - if scoped_name_sym and not sym is scoped_name_sym: + if scoped_name_sym and sym is not scoped_name_sym: + # If its an imported symbol we need to check if its + # the same import interface. + if (isinstance(sym.interface, ImportInterface) and + isinstance(scoped_name_sym.interface, + ImportInterface)): + # If they have the same container symbol name + # then its fine, otherwise we fall into the + # TransformationError + if (sym.interface.container_symbol.name == + scoped_name_sym.interface. + container_symbol.name): + continue raise TransformationError( f"Input node contains an imported symbol whose " f"name collides with an existing symbol, so the " @@ -194,7 +206,7 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): # This is an oversimplification because we could have multiple # references to the same symbol... if isinstance(datatype, ArrayType): - for element in dtype.shape: + for element in datatype.shape: symbols = set() if isinstance(element.lower, DataNode): symbols.update(element.lower.get_all_accessed_symbols()) @@ -203,9 +215,28 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): scope_symbols = node.scope.symbol_table.get_symbols() for sym in symbols: scoped_name_sym = scope_symbols.get(sym.name, None) + # If no symbol with the name exists then create one. if not scoped_name_sym: - sym_copy = symbol.copy() - node.scope.symbol_table.add(sym_copy) + sym_copy = symbol.copy() + if isinstance(sym_copy.interface, ImportInterface): + # Check if the ContainerSymbol is already in the + # interface + container = scope_symbols.get( + sym_copy.interface.container_symbol.name, + None + ) + if container is None: + # Add the container symbol the the symbol table + # and we're ok with this symbol. + node.scope.symbol_table.add( + sym_copy.interface.container_symbol + ) + # If we find the container then we need to update + # the interface to use the container listed. + else: + sym_copy.interface.container_symbol = \ + container + node.scope.symbol_table.add(sym_copy) # Create a Reference to the new symbol new_ref = Reference(symbol) diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py index ec06e9acc0..6ba88714f9 100644 --- a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py @@ -283,7 +283,3 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, assign = psyir.walk(Assignment)[0] dtrans.apply(assign.rhs) out = fortran_writer(psyir) - print(psyir.walk(Assignment)[0].rhs.symbol.shape[1].upper.symbol.interface) - print(psyir.walk(Assignment)[0].rhs.symbol.shape[1].upper.symbol.is_unknown_interface) - print(out) - assert False From ffa39bc956584391b2566e53963484e20e04303d Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Mon, 2 Feb 2026 16:00:53 +0000 Subject: [PATCH 06/10] Fixed incorrect behaviour due to wrong variable usage --- src/psyclone/psyir/transformations/datanode_extract_trans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_extract_trans.py index 3c239d6b65..f791c64078 100644 --- a/src/psyclone/psyir/transformations/datanode_extract_trans.py +++ b/src/psyclone/psyir/transformations/datanode_extract_trans.py @@ -217,7 +217,7 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): scoped_name_sym = scope_symbols.get(sym.name, None) # If no symbol with the name exists then create one. if not scoped_name_sym: - sym_copy = symbol.copy() + sym_copy = sym.copy() if isinstance(sym_copy.interface, ImportInterface): # Check if the ContainerSymbol is already in the # interface From b3a90fcdb97728b6abdad1270853d05b3d1e934d Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 3 Feb 2026 10:49:38 +0000 Subject: [PATCH 07/10] Missing coverage fixes --- .../datanode_extract_trans_test.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py index 6ba88714f9..eb32984db4 100644 --- a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py @@ -39,6 +39,7 @@ import pytest from psyclone.configuration import Config +from psyclone.psyir.frontend.fortran import FortranReader from psyclone.psyir.nodes import ( Assignment, Reference ) @@ -152,6 +153,35 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): "with an existing symbol, so the DataNodeExtractTrans cannot be " "applied. Clashing symbol name is 'i'." in str(err.value)) + # Check validation works when the shape contains a symbol from an + # existing module + filename = os.path.join(str(tmpdir), "a_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module a_mod + use some_mod, only: i + integer, dimension(25, i) :: some_var + end module a_mod + ''') + filename = os.path.join(str(tmpdir), "some_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module some_mod + integer, parameter :: i = 25 + integer, parameter :: j = 30 + end module some_mod + ''') + code = """subroutine test() + use a_mod, only: some_var + use some_mod, only: j + j = some_var(1,3) + end subroutine test""" + # We need to resolve the module in the Frontend to avoid some_Var + # becoming a call. + psyir = FortranReader(resolve_modules=True).psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + dtrans.validate(assign.rhs) + def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, monkeypatch): @@ -283,3 +313,46 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, assign = psyir.walk(Assignment)[0] dtrans.apply(assign.rhs) out = fortran_writer(psyir) + assert """ use some_mod, only : i + integer, dimension(25,50) :: b + integer, dimension(25,i) :: tmp + + tmp = some_var + b = tmp""" in out + + # Check that modules in a shape from an imported module are + # correctly added to the output if the module is already + # present as a Container. + filename = os.path.join(str(tmpdir), "f_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module f_mod + use g_mod, only: i + integer, dimension(25, i) :: some_var + end module f_mod + ''') + filename = os.path.join(str(tmpdir), "g_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module g_mod + integer, parameter :: i = 25 + integer, dimension(25, i) :: j = 30 + end module g_mod + ''') + code = """subroutine test() + use g_mod, only: j + use f_mod, only: some_var + j = some_var + end subroutine test""" + # We need to resolve the module in the Frontend to avoid some_Var + # becoming a call. + psyir = FortranReader(resolve_modules=True).psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + dtrans.apply(assign.rhs) + out = fortran_writer(psyir) + assert """ use g_mod, only : i, j + use f_mod, only : some_var + integer, dimension(25,i) :: tmp + + tmp = some_var + j = tmp""" in out From 5c5fea5626b154b4ed9378802dd640c16455be0c Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 3 Feb 2026 13:27:25 +0000 Subject: [PATCH 08/10] fixed error in validate and coverage --- .../transformations/datanode_extract_trans.py | 30 ++++++++++-------- .../datanode_extract_trans_test.py | 31 +++++++++++++++++++ 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_extract_trans.py index f791c64078..58f6ae0e26 100644 --- a/src/psyclone/psyir/transformations/datanode_extract_trans.py +++ b/src/psyclone/psyir/transformations/datanode_extract_trans.py @@ -49,7 +49,8 @@ UnresolvedType, UnsupportedFortranType, ) -from psyclone.psyir.symbols import DataSymbol, ImportInterface +from psyclone.psyir.symbols import ( + DataSymbol, ImportInterface, ContainerSymbol) from psyclone.utils import transformation_documentation_wrapper @@ -139,24 +140,27 @@ def validate(self, node: DataNode, **kwargs): for sym in symbols: scoped_name_sym = scope_symbols.get(sym.name, None) if scoped_name_sym and sym is not scoped_name_sym: - # If its an imported symbol we need to check if its - # the same import interface. - if (isinstance(sym.interface, ImportInterface) and - isinstance(scoped_name_sym.interface, - ImportInterface)): - # If they have the same container symbol name - # then its fine, otherwise we fall into the - # TransformationError - if (sym.interface.container_symbol.name == - scoped_name_sym.interface. - container_symbol.name): - continue raise TransformationError( f"Input node contains an imported symbol whose " f"name collides with an existing symbol, so the " f"DataNodeExtractTrans cannot be applied. " f"Clashing symbol name is '{sym.name}'." ) + # If its an imported symbol we need to check if its + # the same import interface. + if isinstance(sym.interface, ImportInterface): + scoped_name_sym = scope_symbols.get( + sym.interface.container_symbol.name, + None + ) + if scoped_name_sym and not isinstance( + scoped_name_sym, ContainerSymbol): + raise TransformationError( + f"Input node contains an imported symbol " + f"whose containing module collides with an " + f"existing symbol. Colliding name is " + f"'{sym.interface.container_symbol.name}'." + ) if node.ancestor(Statement) is None: raise TransformationError( diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py index eb32984db4..33ec4fa37e 100644 --- a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py @@ -182,6 +182,37 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): assign = psyir.walk(Assignment)[0] dtrans.validate(assign.rhs) + # Check validation raise an error when the shape contains a symbol from + # a module that overlaps with a symbol in the scope. + filename = os.path.join(str(tmpdir), "tmpmod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module tmpmod + integer, parameter :: i = 25 + integer, parameter :: j = 30 + end module tmpmod + ''') + filename = os.path.join(str(tmpdir), "f_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module f_mod + use tmpmod, only: i + integer, dimension(25, i) :: some_var + end module f_mod + ''') + code = """subroutine test() + use f_mod, only: some_var + integer :: tmpmod + tmpmod = some_var + end subroutine test""" + psyir = FortranReader(resolve_modules=True).psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + with pytest.raises(TransformationError) as err: + dtrans.validate(assign.rhs) + assert ("Input node contains an imported symbol whose containing module " + "collides with an existing symbol. Colliding name is 'tmpmod'." + in str(err.value)) + def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, monkeypatch): From 835fb3bcf6cb57590842683f08e7059db724ff31 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Fri, 6 Feb 2026 15:16:13 +0000 Subject: [PATCH 09/10] Changes for review --- .../psyir/transformations/__init__.py | 4 +- ...act_trans.py => datanode_to_temp_trans.py} | 128 ++++++++++----- ...test.py => datanode_to_temp_trans_test.py} | 146 +++++++++++++----- 3 files changed, 199 insertions(+), 79 deletions(-) rename src/psyclone/psyir/transformations/{datanode_extract_trans.py => datanode_to_temp_trans.py} (69%) rename src/psyclone/tests/psyir/transformations/{datanode_extract_trans_test.py => datanode_to_temp_trans_test.py} (71%) diff --git a/src/psyclone/psyir/transformations/__init__.py b/src/psyclone/psyir/transformations/__init__.py index ebe8208112..5b3a3f8386 100644 --- a/src/psyclone/psyir/transformations/__init__.py +++ b/src/psyclone/psyir/transformations/__init__.py @@ -129,8 +129,8 @@ from psyclone.psyir.transformations.omp_parallel_trans import ( OMPParallelTrans, ) -from psyclone.psyir.transformations.datanode_extract_trans import ( - DataNodeExtractTrans +from psyclone.psyir.transformations.datanode_to_temp_trans import ( + DataNodeToTempTrans ) # For AutoAPI documentation generation diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_to_temp_trans.py similarity index 69% rename from src/psyclone/psyir/transformations/datanode_extract_trans.py rename to src/psyclone/psyir/transformations/datanode_to_temp_trans.py index 58f6ae0e26..ae57417ed5 100644 --- a/src/psyclone/psyir/transformations/datanode_extract_trans.py +++ b/src/psyclone/psyir/transformations/datanode_to_temp_trans.py @@ -33,7 +33,7 @@ # ----------------------------------------------------------------------------- # Authors A. B. G. Chalk STFC Daresbury Lab -'''This module contains the DataNodeExtractTrans class.''' +'''This module contains the DataNodeToTempTrans class.''' from psyclone.psyGen import Transformation from psyclone.psyir.transformations import TransformationError @@ -49,20 +49,24 @@ UnresolvedType, UnsupportedFortranType, ) +from psyclone.psyir.symbols.interfaces import ( + UnresolvedInterface, + UnknownInterface +) from psyclone.psyir.symbols import ( - DataSymbol, ImportInterface, ContainerSymbol) + DataSymbol, ImportInterface, ContainerSymbol, Symbol) from psyclone.utils import transformation_documentation_wrapper @transformation_documentation_wrapper -class DataNodeExtractTrans(Transformation): +class DataNodeToTempTrans(Transformation): """Provides a generic transformation for moving a datanode from a statement into a new standalone statement. For example: >>> from psyclone.psyir.frontend.fortran import FortranReader >>> from psyclone.psyir.backend.fortran import FortranWriter >>> from psyclone.psyir.nodes import Assignment - >>> from psyclone.psyir.transformations import DataNodeExtractTrans + >>> from psyclone.psyir.transformations import DataNodeToTempTrans >>> >>> psyir = FortranReader().psyir_from_source(''' ... subroutine my_subroutine() @@ -72,7 +76,7 @@ class DataNodeExtractTrans(Transformation): ... end subroutine ... ''') >>> assign = psyir.walk(Assignment)[0] - >>> DataNodeExtractTrans().apply(assign.rhs, storage_name="temp") + >>> DataNodeToTempTrans().apply(assign.rhs, storage_name="temp") >>> print(FortranWriter()(psyir)) subroutine my_subroutine() integer, dimension(10,10) :: a @@ -107,7 +111,7 @@ def validate(self, node: DataNode, **kwargs): if not isinstance(node, DataNode): raise TypeError( - f"Input node to DataNodeExtractTrans should be a " + f"Input node to DataNodeToTempTrans should be a " f"DataNode but got '{type(node).__name__}'." ) @@ -117,8 +121,9 @@ def validate(self, node: DataNode, **kwargs): for call in calls: if not call.is_pure: raise TransformationError( - f"Input node to DataNodeExtractTrans contains a call " - f"that is not guaranteed to be pure. Input node is " + f"Input node to DataNodeToTempTrans contains a call " + f"'{call.debug_string().strip()}' that is not guaranteed " + f"to be pure. Input node is " f"'{node.debug_string().strip()}'." ) if isinstance(dtype, ArrayType): @@ -127,24 +132,46 @@ def validate(self, node: DataNode, **kwargs): ArrayType.Extent.ATTRIBUTE]: raise TransformationError( f"Input node's datatype is an array of unknown size, " - f"so the DataNodeExtractTrans cannot be applied. " + f"so the DataNodeToTempTrans cannot be applied. " f"Input node was '{node.debug_string().strip()}'." ) - # Otherwise we have an ArrayBounds + # The shape must now be set by ArrayBounds, we need to + # examine the symbols used to define those bounds. symbols = set() if isinstance(element.lower, DataNode): symbols.update(element.lower.get_all_accessed_symbols()) if isinstance(element.upper, DataNode): symbols.update(element.upper.get_all_accessed_symbols()) + # Compare the symbols in the array bounds with the symbols + # already in the scope. scope_symbols = node.scope.symbol_table.get_symbols() for sym in symbols: scoped_name_sym = scope_symbols.get(sym.name, None) + # If sym is not scoped_name_sym, then there is a + # symbol collision from an imported symbol. if scoped_name_sym and sym is not scoped_name_sym: + # If the symbol in scoped is imported from the same + # container then we can skip this. + if (isinstance(scoped_name_sym.interface, + ImportInterface) and + (scoped_name_sym.interface.container_symbol.name + == sym.interface.container_symbol.name)): + continue + raise TransformationError( + f"The type of the node supplied to {self.name} " + f"depends upon an imported symbol '{sym.name}' " + f"which has a name clash with a symbol in the " + f"current scope." + ) + # If its not in the current scope, and its visibility is + # private then we can't import it. + if (not scoped_name_sym and sym.visibility == + Symbol.Visibility.PRIVATE): raise TransformationError( - f"Input node contains an imported symbol whose " - f"name collides with an existing symbol, so the " - f"DataNodeExtractTrans cannot be applied. " - f"Clashing symbol name is '{sym.name}'." + f"The datatype of the node suppled to " + f"{self.name} depends upon an imported symbol " + f"'{sym.name}' that is declared as private in " + f"its containing module, so cannot be imported." ) # If its an imported symbol we need to check if its # the same import interface. @@ -157,29 +184,46 @@ def validate(self, node: DataNode, **kwargs): scoped_name_sym, ContainerSymbol): raise TransformationError( f"Input node contains an imported symbol " - f"whose containing module collides with an " - f"existing symbol. Colliding name is " + f"'{sym.name}' whose containing module " + f"collides with an existing symbol. Colliding " + f"name is " f"'{sym.interface.container_symbol.name}'." ) if node.ancestor(Statement) is None: raise TransformationError( - "Input node to DataNodeExtractTrans has no ancestor " + "Input node to DataNodeToTempTrans has no ancestor " "Statement node which is not supported." ) if isinstance(dtype, (UnresolvedType, UnsupportedFortranType)): - raise TransformationError( + failing_symbols = [] + symbols = node.get_all_accessed_symbols() + for sym in symbols: + if isinstance(sym.interface, (UnresolvedInterface, + UnknownInterface)): + failing_symbols.append(sym.name) + # Sort the order of the list to get consistant results for tests. + failing_symbols.sort() + message = ( f"Input node's datatype cannot be computed, so the " - f"DataNodeExtractTrans cannot be applied. Input node was " + f"DataNodeToTempTrans cannot be applied. Input node was " f"'{node.debug_string().strip()}'." ) + if failing_symbols: + message += ( + f" The following symbols in the input node are not " + f"resolved in the scope: '{failing_symbols}'. Setting " + f"RESOLVE_IMPORTS in the transformation script may " + f"enable resolution of these symbols." + ) + raise TransformationError(message) def apply(self, node: DataNode, storage_name: str = "", **kwargs): - """Applies the DataNodeExtractTrans to the input arguments. + """Applies the DataNodeToTempTrans to the input arguments. :param node: The datanode to extract. - :param storage_name: The name of the temporary variable to store + :param storage_name: The base name of the temporary variable to store the result of the input node in. The default is tmp(_...) based on the rules defined in the SymbolTable class. """ @@ -189,26 +233,14 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): # Find the datatype datatype = node.datatype - # Create a symbol of the relevant type. - if not storage_name: - symbol = node.scope.symbol_table.new_symbol( - root_name="tmp", - symbol_type=DataSymbol, - datatype=datatype - ) - else: - symbol = node.scope.symbol_table.new_symbol( - root_name=storage_name, - symbol_type=DataSymbol, - allow_renaming=False, - datatype=datatype - ) - - # FIXME Make sure the shape is all in the symbol table. We know that + # Make sure the shape is all in the symbol table. We know that # all symbols we find can be safely added as otherwise validate will # fail. - # This is an oversimplification because we could have multiple - # references to the same symbol... + # Symbols used to reference shapes that are from imported modules but + # that aren't currently in the symbol table will be placed into the + # symbol table with a corresponding ImportInterface so the resultant + # symbol will reference the original definition of the shape in the + # containing module. if isinstance(datatype, ArrayType): for element in datatype.shape: symbols = set() @@ -241,7 +273,23 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): sym_copy.interface.container_symbol = \ container node.scope.symbol_table.add(sym_copy) + # Now we've created the relevant symbols, we need to update + # the datatype to use the in-scope symbols + datatype.replace_symbols_using(node.scope.symbol_table) + # Create a symbol of the relevant type. + if not storage_name: + symbol = node.scope.symbol_table.new_symbol( + root_name="tmp", + symbol_type=DataSymbol, + datatype=datatype + ) + else: + symbol = node.scope.symbol_table.new_symbol( + root_name=storage_name, + symbol_type=DataSymbol, + datatype=datatype + ) # Create a Reference to the new symbol new_ref = Reference(symbol) @@ -260,4 +308,4 @@ def apply(self, node: DataNode, storage_name: str = "", **kwargs): parent.addchild(assign, pos) -__all__ = ["DataNodeExtractTrans"] +__all__ = ["DataNodeToTempTrans"] diff --git a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py b/src/psyclone/tests/psyir/transformations/datanode_to_temp_trans_test.py similarity index 71% rename from src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py rename to src/psyclone/tests/psyir/transformations/datanode_to_temp_trans_test.py index 33ec4fa37e..8aab48f456 100644 --- a/src/psyclone/tests/psyir/transformations/datanode_extract_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/datanode_to_temp_trans_test.py @@ -33,7 +33,7 @@ # ----------------------------------------------------------------------------- # Authors A. B. G. Chalk STFC Daresbury Lab -'''This module contains the DataNodeExtractTrans class.''' +'''This module contains the DataNodeToTempTrans class.''' import os import pytest @@ -47,13 +47,15 @@ DataSymbol, INTEGER_TYPE ) from psyclone.psyir.transformations import ( - DataNodeExtractTrans, TransformationError + DataNodeToTempTrans, TransformationError ) +from psyclone.tests.utilities import Compile -def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): - """Tests the validate function of the DataNodeExtractTrans.""" - dtrans = DataNodeExtractTrans() +def test_datanodetotemptrans_validate(fortran_reader, tmp_path): + """Tests the non-import related functionality of the validate + function of the DataNodeToTempTrans.""" + dtrans = DataNodeToTempTrans() code = """subroutine test(a, b, c) integer, dimension(:,:), intent(inout) :: a, b, c c = b + a @@ -63,7 +65,7 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): with pytest.raises(TransformationError) as err: dtrans.validate(assign.rhs) assert ("Input node's datatype is an array of unknown size, so the " - "DataNodeExtractTrans cannot be applied. Input node was " + "DataNodeToTempTrans cannot be applied. Input node was " "'b + a'" in str(err.value)) code = """subroutine test @@ -75,8 +77,11 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): with pytest.raises(TransformationError) as err: dtrans.validate(assign.rhs) assert ("Input node's datatype cannot be computed, so the " - "DataNodeExtractTrans cannot be applied. Input node " - "was 'b + a'" in str(err.value)) + "DataNodeToTempTrans cannot be applied. Input node " + "was 'b + a'. The following symbols in the input " + "node are not resolved in the scope: '['a', 'b']'. " + "Setting RESOLVE_IMPORTS in the transformation script " + "may enable resolution of these symbols." in str(err.value)) code = """subroutine test character(len=25) :: a, b @@ -88,24 +93,24 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): with pytest.raises(TransformationError) as err: dtrans.validate(assign.rhs) assert ("Input node's datatype cannot be computed, so the " - "DataNodeExtractTrans cannot be applied. Input node " + "DataNodeToTempTrans cannot be applied. Input node " "was 'a'" in str(err.value)) with pytest.raises(TypeError) as err: dtrans.validate("abc") - assert ("Input node to DataNodeExtractTrans should be a " + assert ("Input node to DataNodeToTempTrans should be a " "DataNode but got 'str'." in str(err.value)) with pytest.raises(TypeError) as err: dtrans.validate(assign.rhs, storage_name=1) - assert ("'DataNodeExtractTrans' received options with the wrong types:\n" + assert ("'DataNodeToTempTrans' received options with the wrong types:\n" "'storage_name' option expects type 'str' but received '1' of " "type 'int'.\nPlease see the documentation and check the " "provided types." in str(err.value)) with pytest.raises(TransformationError) as err: dtrans.validate(Reference(DataSymbol("a", INTEGER_TYPE))) - assert ("Input node to DataNodeExtractTrans has no ancestor Statement " + assert ("Input node to DataNodeToTempTrans has no ancestor Statement " "node which is not supported." in str(err.value)) code = """module some_mod @@ -125,12 +130,20 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): assign = psyir.walk(Assignment)[2] with pytest.raises(TransformationError) as err: dtrans.validate(assign.rhs) - assert ("Input node to DataNodeExtractTrans contains a call that is not " + assert ("Input node to DataNodeToTempTrans contains a call " + "'some_func(a, b)' that is not " "guaranteed to be pure. Input node is 'a + some_func(a, b)'." in str(err.value)) - monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)]) - filename = os.path.join(str(tmpdir), "a_mod.f90") + +def test_datanodetotemptrans_validate_imports( + fortran_reader, tmp_path, monkeypatch +): + """Tests the import related functionality of the validate + function of the DataNodeToTempTrans.""" + dtrans = DataNodeToTempTrans() + monkeypatch.setattr(Config.get(), '_include_paths', [str(tmp_path)]) + filename = os.path.join(str(tmp_path), "a_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module a_mod @@ -149,13 +162,33 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): assign = psyir.walk(Assignment)[0] with pytest.raises(TransformationError) as err: dtrans.validate(assign.rhs) - assert ("Input node contains an imported symbol whose name collides " - "with an existing symbol, so the DataNodeExtractTrans cannot be " - "applied. Clashing symbol name is 'i'." in str(err.value)) + assert ("The type of the node supplied to DataNodeToTempTrans depends " + "upon an imported symbol 'i' which has a name clash with a " + "symbol in the current scope." in str(err.value)) + + # This should work if the i in scope is imported from the + # some_mod already. + filename = os.path.join(str(tmp_path), "some_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module some_mod + integer, parameter :: i = 50 + end module some_mod + ''') + code = """subroutine test() + use some_mod, only: i + use a_mod, only: some_var + integer, dimension(25, 50) :: b + b = some_var + end subroutine test""" + psyir = fortran_reader.psyir_from_source(code) + psyir.children[0].symbol_table.resolve_imports() + assign = psyir.walk(Assignment)[0] + dtrans.validate(assign.rhs) # Check validation works when the shape contains a symbol from an # existing module - filename = os.path.join(str(tmpdir), "a_mod.f90") + filename = os.path.join(str(tmp_path), "a_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module a_mod @@ -163,17 +196,17 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): integer, dimension(25, i) :: some_var end module a_mod ''') - filename = os.path.join(str(tmpdir), "some_mod.f90") + filename = os.path.join(str(tmp_path), "some_mod2.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' - module some_mod + module some_mod2 integer, parameter :: i = 25 - integer, parameter :: j = 30 - end module some_mod + integer :: j + end module some_mod2 ''') code = """subroutine test() use a_mod, only: some_var - use some_mod, only: j + use some_mod2, only: j j = some_var(1,3) end subroutine test""" # We need to resolve the module in the Frontend to avoid some_Var @@ -184,7 +217,7 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): # Check validation raise an error when the shape contains a symbol from # a module that overlaps with a symbol in the scope. - filename = os.path.join(str(tmpdir), "tmpmod.f90") + filename = os.path.join(str(tmp_path), "tmpmod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module tmpmod @@ -192,7 +225,7 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): integer, parameter :: j = 30 end module tmpmod ''') - filename = os.path.join(str(tmpdir), "f_mod.f90") + filename = os.path.join(str(tmp_path), "f_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module f_mod @@ -209,15 +242,40 @@ def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch): assign = psyir.walk(Assignment)[0] with pytest.raises(TransformationError) as err: dtrans.validate(assign.rhs) - assert ("Input node contains an imported symbol whose containing module " - "collides with an existing symbol. Colliding name is 'tmpmod'." + assert ("Input node contains an imported symbol 'i' whose containing " + "module collides with an existing symbol. Colliding name is " + "'tmpmod'." in str(err.value)) + filename = os.path.join(str(tmp_path), "some_other_mod.f90") + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' + module some_other_mod + integer, parameter :: dim1 = 4, dim2 = 5 + real(kind=wp), dimension(dim1, dim2) :: a_var + public :: a_var + private +end module''') + code = """subroutine test() + use some_other_mod, only: a_var + integer, dimension(4, 5) :: b + b = a_var +end subroutine test + """ + psyir = FortranReader(resolve_modules=True).psyir_from_source(code) + assign = psyir.walk(Assignment)[0] + with pytest.raises(TransformationError) as err: + dtrans.validate(assign.rhs) + assert ("The datatype of the node suppled to DataNodeToTempTrans depends " + "upon an imported symbol 'dim1' that is declared as private in " + "its containing module, so cannot be imported." in str(err.value)) + -def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, +def test_datanodetotemptrans_apply(fortran_reader, fortran_writer, tmp_path, monkeypatch): - """Tests the apply function of the DataNodeExtractTrans.""" - dtrans = DataNodeExtractTrans() + """Tests the apply function of the DataNodeToTempTrans without imported + symbols.""" + dtrans = DataNodeToTempTrans() code = """subroutine test() integer, dimension(10,100) :: a integer, dimension(100,10) :: b @@ -232,6 +290,7 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, in out) assert "tmp = MATMUL(a, b)" in out assert "d = c + tmp" in out + assert Compile(tmp_path).string_compiles(out) code = """subroutine test() real :: a @@ -246,6 +305,7 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, assert "integer :: temporary" in out assert "temporary = INT(a)" in out assert "b = temporary" in out + assert Compile(tmp_path).string_compiles(out) code = """subroutine test() real, dimension(100) :: b @@ -279,10 +339,18 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, assert " integer, dimension(3) :: tmp" in out assert """ tmp = 3 * b a(:4) = tmp""" in out + assert Compile(tmp_path).string_compiles(out) + +def test_datanodetotemptrans_apply_imports( + fortran_reader, fortran_writer, tmp_path, monkeypatch +): + """Tests the apply function of the DataNodeToTempTrans with imported + symbols.""" + dtrans = DataNodeToTempTrans() # Test the imports are handled correctly. - monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)]) - filename = os.path.join(str(tmpdir), "a_mod.f90") + monkeypatch.setattr(Config.get(), '_include_paths', [str(tmp_path)]) + filename = os.path.join(str(tmp_path), "a_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module a_mod @@ -303,8 +371,9 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, tmp = some_var b = tmp""" in out + assert Compile(tmp_path).string_compiles(out) - filename = os.path.join(str(tmpdir), "b_mod.f90") + filename = os.path.join(str(tmp_path), "b_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module b_mod @@ -325,8 +394,9 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, tmp = some_var b = tmp""" in out + assert Compile(tmp_path).string_compiles(out) - filename = os.path.join(str(tmpdir), "c_mod.f90") + filename = os.path.join(str(tmp_path), "c_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module c_mod @@ -350,11 +420,12 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, tmp = some_var b = tmp""" in out + assert Compile(tmp_path).string_compiles(out) # Check that modules in a shape from an imported module are # correctly added to the output if the module is already # present as a Container. - filename = os.path.join(str(tmpdir), "f_mod.f90") + filename = os.path.join(str(tmp_path), "f_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module f_mod @@ -362,7 +433,7 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, integer, dimension(25, i) :: some_var end module f_mod ''') - filename = os.path.join(str(tmpdir), "g_mod.f90") + filename = os.path.join(str(tmp_path), "g_mod.f90") with open(filename, "w", encoding='UTF-8') as module: module.write(''' module g_mod @@ -387,3 +458,4 @@ def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir, tmp = some_var j = tmp""" in out + assert Compile(tmp_path).string_compiles(out) From 801ebd25e744b83cfa72b35ba5184c942e29bd82 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Fri, 6 Feb 2026 15:44:00 +0000 Subject: [PATCH 10/10] Add the transformation in the User guide transformation list --- doc/user_guide/transformations.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/user_guide/transformations.rst b/doc/user_guide/transformations.rst index a505b84424..23b711f893 100644 --- a/doc/user_guide/transformations.rst +++ b/doc/user_guide/transformations.rst @@ -529,6 +529,12 @@ can be found in the API-specific sections). #### +.. autoclass:: psyclone.psyir.transformations.DataNodeToTempTrans + :members: apply + :no-index: + +#### + Algorithm-layer ---------------