diff --git a/src/psyclone/psyir/transformations/__init__.py b/src/psyclone/psyir/transformations/__init__.py index 9f186ee31e..ebe8208112 100644 --- a/src/psyclone/psyir/transformations/__init__.py +++ b/src/psyclone/psyir/transformations/__init__.py @@ -129,6 +129,9 @@ from psyclone.psyir.transformations.omp_parallel_trans import ( OMPParallelTrans, ) +from psyclone.psyir.transformations.datanode_extract_trans import ( + DataNodeExtractTrans +) # For AutoAPI documentation generation __all__ = [ diff --git a/src/psyclone/psyir/transformations/datanode_extract_trans.py b/src/psyclone/psyir/transformations/datanode_extract_trans.py new file mode 100644 index 0000000000..58f6ae0e26 --- /dev/null +++ b/src/psyclone/psyir/transformations/datanode_extract_trans.py @@ -0,0 +1,263 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
'''This module contains the DataNodeExtractTrans class.'''

from psyclone.psyGen import Transformation
from psyclone.psyir.transformations import TransformationError
from psyclone.psyir.nodes import (
    DataNode,
    Reference,
    Assignment,
    Statement,
    Call
)
from psyclone.psyir.symbols.datatypes import (
    ArrayType,
    UnresolvedType,
    UnsupportedFortranType,
)
from psyclone.psyir.symbols import (
    DataSymbol, ImportInterface, ContainerSymbol)
from psyclone.utils import transformation_documentation_wrapper


@transformation_documentation_wrapper
class DataNodeExtractTrans(Transformation):
    """Provides a generic transformation for moving a datanode from a
    statement into a new standalone statement. For example:

    >>> from psyclone.psyir.frontend.fortran import FortranReader
    >>> from psyclone.psyir.backend.fortran import FortranWriter
    >>> from psyclone.psyir.nodes import Assignment
    >>> from psyclone.psyir.transformations import DataNodeExtractTrans
    >>>
    >>> psyir = FortranReader().psyir_from_source('''
    ...     subroutine my_subroutine()
    ...         integer :: i
    ...         integer :: j
    ...         i = j * 2
    ...     end subroutine
    ...     ''')
    >>> assign = psyir.walk(Assignment)[0]
    >>> DataNodeExtractTrans().apply(assign.rhs, storage_name="temp")
    >>> print(FortranWriter()(psyir))
    subroutine my_subroutine()
      integer :: i
      integer :: j
      integer :: temp

      temp = j * 2
      i = temp

    end subroutine my_subroutine

    """

    @staticmethod
    def _symbols_in_bounds(bounds) -> set:
        """Collects the symbols accessed by the bound expressions of a
        single array dimension.

        :param bounds: one entry of an ArrayType shape that has ``lower``
            and ``upper`` attributes.

        :returns: the symbols accessed by whichever of the two bounds
            are DataNodes (empty if neither is).
        """
        symbols = set()
        for bound in (bounds.lower, bounds.upper):
            # A bound that is not a DataNode has no symbols to collect.
            if isinstance(bound, DataNode):
                symbols.update(bound.get_all_accessed_symbols())
        return symbols

    def validate(self, node: DataNode, **kwargs) -> None:
        """Validity checks for input arguments

        :param node: The DataNode to be extracted.

        :raises TypeError: if the input arguments are the wrong types.
        :raises TransformationError: if the input node's datatype can't be
            resolved.
        :raises TransformationError: if the input node's datatype is an array
            but any of the array's dimensions are unknown.
        :raises TransformationError: if the input node doesn't have an
            ancestor statement.
        :raises TransformationError: if the input node contains a call
            that isn't guaranteed to be pure.
        :raises TransformationError: if a symbol used in the bounds of the
            input node's array datatype has the same name as a different
            symbol that is already in scope.
        :raises TransformationError: if a symbol used in the bounds of the
            input node's array datatype is imported from a container whose
            name matches a non-container symbol that is already in scope.
        """
        # Validate the input options and types.
        self.validate_options(**kwargs)

        if not isinstance(node, DataNode):
            raise TypeError(
                f"Input node to DataNodeExtractTrans should be a "
                f"DataNode but got '{type(node).__name__}'."
            )

        dtype = node.datatype

        for call in node.walk(Call):
            if not call.is_pure:
                raise TransformationError(
                    f"Input node to DataNodeExtractTrans contains a call "
                    f"that is not guaranteed to be pure. Input node is "
                    f"'{node.debug_string().strip()}'."
                )

        # This check has to come before the array checks below because
        # those use node.scope, which needs the node to be attached to a
        # tree.
        if node.ancestor(Statement) is None:
            raise TransformationError(
                "Input node to DataNodeExtractTrans has no ancestor "
                "Statement node which is not supported."
            )

        if isinstance(dtype, ArrayType):
            unknown_extents = (ArrayType.Extent.DEFERRED,
                               ArrayType.Extent.ATTRIBUTE)
            for element in dtype.shape:
                if element in unknown_extents:
                    raise TransformationError(
                        f"Input node's datatype is an array of unknown size, "
                        f"so the DataNodeExtractTrans cannot be applied. "
                        f"Input node was '{node.debug_string().strip()}'."
                    )
                # Otherwise we have an ArrayBounds.
                # NOTE(review): a bound that is not a DataNode is silently
                # ignored here - confirm whether an upper bound of unknown
                # extent can reach this point and should be rejected too.
                scope_symbols = node.scope.symbol_table.get_symbols()
                for sym in self._symbols_in_bounds(element):
                    scoped_name_sym = scope_symbols.get(sym.name, None)
                    if scoped_name_sym and sym is not scoped_name_sym:
                        raise TransformationError(
                            f"Input node contains an imported symbol whose "
                            f"name collides with an existing symbol, so the "
                            f"DataNodeExtractTrans cannot be applied. "
                            f"Clashing symbol name is '{sym.name}'."
                        )
                    # If it's an imported symbol we need to check that the
                    # name of its container isn't already used by
                    # something that is not a container.
                    if isinstance(sym.interface, ImportInterface):
                        scoped_name_sym = scope_symbols.get(
                            sym.interface.container_symbol.name,
                            None
                        )
                        if scoped_name_sym and not isinstance(
                                scoped_name_sym, ContainerSymbol):
                            raise TransformationError(
                                f"Input node contains an imported symbol "
                                f"whose containing module collides with an "
                                f"existing symbol. Colliding name is "
                                f"'{sym.interface.container_symbol.name}'."
                            )

        if isinstance(dtype, (UnresolvedType, UnsupportedFortranType)):
            raise TransformationError(
                f"Input node's datatype cannot be computed, so the "
                f"DataNodeExtractTrans cannot be applied. Input node was "
                f"'{node.debug_string().strip()}'."
            )

    def apply(self, node: DataNode, storage_name: str = "",
              **kwargs) -> None:
        """Applies the DataNodeExtractTrans to the input arguments.

        :param node: The datanode to extract.
        :param storage_name: The name of the temporary variable to store
            the result of the input node in. The default is tmp(_...)
            based on the rules defined in the SymbolTable class.
        """
        # Call validate to check inputs are valid.
        self.validate(node, storage_name=storage_name, **kwargs)

        # Find the datatype
        datatype = node.datatype
        table = node.scope.symbol_table

        # Create a symbol of the relevant type. A user-supplied name must
        # be used exactly as given, so renaming is disabled in that case.
        if not storage_name:
            symbol = table.new_symbol(
                root_name="tmp",
                symbol_type=DataSymbol,
                datatype=datatype
            )
        else:
            symbol = table.new_symbol(
                root_name=storage_name,
                symbol_type=DataSymbol,
                allow_renaming=False,
                datatype=datatype
            )

        # FIXME Make sure the shape is all in the symbol table. We know that
        # all symbols we find can be safely added as otherwise validate will
        # fail.
        # This is an oversimplification because we could have multiple
        # references to the same symbol...
        if isinstance(datatype, ArrayType):
            for element in datatype.shape:
                for sym in self._symbols_in_bounds(element):
                    # Re-read the table for every symbol: this loop adds
                    # symbols (and containers), so a mapping fetched
                    # earlier may no longer reflect what is in scope.
                    scope_symbols = table.get_symbols()
                    if scope_symbols.get(sym.name, None):
                        # A symbol with this name is already in scope.
                        continue
                    sym_copy = sym.copy()
                    if isinstance(sym_copy.interface, ImportInterface):
                        # Check if the ContainerSymbol is already in scope.
                        container = scope_symbols.get(
                            sym_copy.interface.container_symbol.name,
                            None
                        )
                        if container is None:
                            # Add the container symbol to the symbol table
                            # and we're ok with this symbol.
                            table.add(sym_copy.interface.container_symbol)
                        else:
                            # The container is already present so the
                            # copy must import from that one.
                            sym_copy.interface.container_symbol = container
                    table.add(sym_copy)

        # Create a Reference to the new symbol
        new_ref = Reference(symbol)

        # Find the parent and position of the statement containing the
        # DataNode. This must be done before the node is detached below.
        statement = node.ancestor(Statement)
        parent = statement.parent
        pos = statement.position

        # Replace the datanode with the new reference
        node.replace_with(new_ref)

        # Create an assignment to set the value of the new symbol. The
        # (now detached) original node becomes its right-hand side.
        assign = Assignment.create(new_ref.copy(), node)

        # Add the assignment into the tree, immediately before the
        # statement that originally contained the node.
        parent.addchild(assign, pos)


__all__ = ["DataNodeExtractTrans"]
'''This module contains the tests for the DataNodeExtractTrans class.'''

import os
import pytest

from psyclone.configuration import Config
from psyclone.psyir.frontend.fortran import FortranReader
from psyclone.psyir.nodes import (
    Assignment, Reference
)
from psyclone.psyir.symbols import (
    DataSymbol, INTEGER_TYPE
)
from psyclone.psyir.transformations import (
    DataNodeExtractTrans, TransformationError
)


def _write_module(directory, filename, source):
    '''Writes the supplied Fortran source to a file.

    :param directory: the directory in which to create the file.
    :param str filename: the name of the file to create.
    :param str source: the Fortran source to write to the file.
    '''
    path = os.path.join(str(directory), filename)
    with open(path, "w", encoding='UTF-8') as module:
        module.write(source)


def test_datanodeextracttrans_validate(fortran_reader, tmpdir, monkeypatch):
    """Tests the validate function of the DataNodeExtractTrans."""
    dtrans = DataNodeExtractTrans()

    # Arrays of unknown size are rejected.
    code = """subroutine test(a, b, c)
        integer, dimension(:,:), intent(inout) :: a, b, c
        c = b + a
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    with pytest.raises(TransformationError) as err:
        dtrans.validate(assign.rhs)
    assert ("Input node's datatype is an array of unknown size, so the "
            "DataNodeExtractTrans cannot be applied. Input node was "
            "'b + a'" in str(err.value))

    # Symbols that come from an unresolved wildcard import are rejected.
    code = """subroutine test
        use some_mod
        c = b + a
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    with pytest.raises(TransformationError) as err:
        dtrans.validate(assign.rhs)
    assert ("Input node's datatype cannot be computed, so the "
            "DataNodeExtractTrans cannot be applied. Input node "
            "was 'b + a'" in str(err.value))

    # Character variables with a length are rejected.
    code = """subroutine test
        character(len=25) :: a, b

        b = a
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    with pytest.raises(TransformationError) as err:
        dtrans.validate(assign.rhs)
    assert ("Input node's datatype cannot be computed, so the "
            "DataNodeExtractTrans cannot be applied. Input node "
            "was 'a'" in str(err.value))

    # The node must be a DataNode.
    with pytest.raises(TypeError) as err:
        dtrans.validate("abc")
    assert ("Input node to DataNodeExtractTrans should be a "
            "DataNode but got 'str'." in str(err.value))

    # The options must have the correct types.
    with pytest.raises(TypeError) as err:
        dtrans.validate(assign.rhs, storage_name=1)
    assert ("'DataNodeExtractTrans' received options with the wrong types:\n"
            "'storage_name' option expects type 'str' but received '1' of "
            "type 'int'.\nPlease see the documentation and check the "
            "provided types." in str(err.value))

    # The node must be inside a Statement.
    with pytest.raises(TransformationError) as err:
        dtrans.validate(Reference(DataSymbol("a", INTEGER_TYPE)))
    assert ("Input node to DataNodeExtractTrans has no ancestor Statement "
            "node which is not supported." in str(err.value))

    # Calls that are not known to be pure are rejected.
    code = """module some_mod
        contains
        integer function some_func(a, b)
            integer :: a, b
            a = a + b
            some_func = a + b
        end function
        subroutine test()
            integer :: a, b

            a = a + some_func(a,b)
        end subroutine test
    end module"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[2]
    with pytest.raises(TransformationError) as err:
        dtrans.validate(assign.rhs)
    assert ("Input node to DataNodeExtractTrans contains a call that is not "
            "guaranteed to be pure. Input node is 'a + some_func(a, b)'."
            in str(err.value))

    # A symbol in the shape of an imported array must not clash with a
    # different symbol of the same name in the local scope.
    monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)])
    a_mod_source = '''
    module a_mod
        use some_mod, only: i
        integer, dimension(25, i) :: some_var
    end module a_mod
    '''
    _write_module(tmpdir, "a_mod.f90", a_mod_source)
    code = """subroutine test()
        use a_mod, only: some_var
        integer, dimension(25, 50) :: b
        integer :: i
        b = some_var
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    psyir.children[0].symbol_table.resolve_imports()
    assign = psyir.walk(Assignment)[0]
    with pytest.raises(TransformationError) as err:
        dtrans.validate(assign.rhs)
    assert ("Input node contains an imported symbol whose name collides "
            "with an existing symbol, so the DataNodeExtractTrans cannot be "
            "applied. Clashing symbol name is 'i'." in str(err.value))

    # Check validation works when the shape contains a symbol from an
    # existing module
    _write_module(tmpdir, "a_mod.f90", a_mod_source)
    _write_module(tmpdir, "some_mod.f90", '''
    module some_mod
        integer, parameter :: i = 25
        integer, parameter :: j = 30
    end module some_mod
    ''')
    code = """subroutine test()
        use a_mod, only: some_var
        use some_mod, only: j
        j = some_var(1,3)
    end subroutine test"""
    # We need to resolve the module in the Frontend to avoid some_var
    # becoming a call.
    psyir = FortranReader(resolve_modules=True).psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    dtrans.validate(assign.rhs)

    # Check validation raises an error when the shape contains a symbol
    # from a module whose name overlaps with a symbol in the scope.
    _write_module(tmpdir, "tmpmod.f90", '''
    module tmpmod
        integer, parameter :: i = 25
        integer, parameter :: j = 30
    end module tmpmod
    ''')
    _write_module(tmpdir, "f_mod.f90", '''
    module f_mod
        use tmpmod, only: i
        integer, dimension(25, i) :: some_var
    end module f_mod
    ''')
    code = """subroutine test()
        use f_mod, only: some_var
        integer :: tmpmod
        tmpmod = some_var
    end subroutine test"""
    psyir = FortranReader(resolve_modules=True).psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    with pytest.raises(TransformationError) as err:
        dtrans.validate(assign.rhs)
    assert ("Input node contains an imported symbol whose containing module "
            "collides with an existing symbol. Colliding name is 'tmpmod'."
            in str(err.value))


def test_datanodeextractrans_apply(fortran_reader, fortran_writer, tmpdir,
                                   monkeypatch):
    """Tests the apply function of the DataNodeExtractTrans."""
    dtrans = DataNodeExtractTrans()

    # Extract an intrinsic call that has an array result.
    code = """subroutine test()
        integer, dimension(10,100) :: a
        integer, dimension(100,10) :: b
        integer, dimension(10,10) :: c, d
        d = c + MATMUL(a, b)
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs.operands[1])
    out = fortran_writer(psyir)
    assert ("integer, dimension(SIZE(a, dim=1),SIZE(b, dim=2)) :: tmp"
            in out)
    assert "tmp = MATMUL(a, b)" in out
    assert "d = c + tmp" in out

    # Extract a scalar into a user-named variable.
    code = """subroutine test()
        real :: a
        integer :: b

        b = INT(a)
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs, storage_name="temporary")
    out = fortran_writer(psyir)
    assert "integer :: temporary" in out
    assert "temporary = INT(a)" in out
    assert "b = temporary" in out

    # The new assignment is placed inside the loop containing the node.
    code = """subroutine test()
        real, dimension(100) :: b
        integer :: i

        do i = 1, 100
            b(i) = REAL(i)
        end do
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs, storage_name="temporary")
    out = fortran_writer(psyir)
    assert "  real :: temporary" in out
    assert """  do i = 1, 100, 1
    temporary = REAL(i)
    b(i) = temporary
  enddo""" in out

    # Extract an array expression whose shape is known.
    code = """subroutine test()
        integer, dimension(2:6) :: a
        integer, dimension(1:3) :: b

        a(2:4) = 3 * b

    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs)
    out = fortran_writer(psyir)
    assert "  integer, dimension(3) :: tmp" in out
    assert """  tmp = 3 * b
  a(:4) = tmp""" in out

    # Test the imports are handled correctly.
    monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)])
    _write_module(tmpdir, "a_mod.f90", '''
    module a_mod
        integer :: some_var
    end module a_mod
    ''')
    code = """subroutine test()
        use a_mod
        integer :: b
        b = some_var
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    psyir.children[0].symbol_table.resolve_imports()
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs)
    out = fortran_writer(psyir)
    assert """  integer :: tmp

  tmp = some_var
  b = tmp""" in out

    _write_module(tmpdir, "b_mod.f90", '''
    module b_mod
        integer, dimension(25, 50) :: some_var
    end module b_mod
    ''')
    code = """subroutine test()
        use b_mod
        integer, dimension(25, 50) :: b
        b = some_var
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    psyir.children[0].symbol_table.resolve_imports()
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs)
    out = fortran_writer(psyir)
    assert """  integer, dimension(25,50) :: tmp

  tmp = some_var
  b = tmp""" in out

    # A symbol used in the shape of the imported array must itself be
    # imported into the routine.
    _write_module(tmpdir, "c_mod.f90", '''
    module c_mod
        use some_mod, only: i
        integer, dimension(25, i) :: some_var
    end module c_mod
    ''')
    code = """subroutine test()
        use c_mod, only: some_var
        integer, dimension(25, 50) :: b
        b = some_var
    end subroutine test"""
    psyir = fortran_reader.psyir_from_source(code)
    psyir.children[0].symbol_table.resolve_imports()
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs)
    out = fortran_writer(psyir)
    assert """  use some_mod, only : i
  integer, dimension(25,50) :: b
  integer, dimension(25,i) :: tmp

  tmp = some_var
  b = tmp""" in out

    # Check that modules in a shape from an imported module are
    # correctly added to the output if the module is already
    # present as a Container.
    _write_module(tmpdir, "f_mod.f90", '''
    module f_mod
        use g_mod, only: i
        integer, dimension(25, i) :: some_var
    end module f_mod
    ''')
    _write_module(tmpdir, "g_mod.f90", '''
    module g_mod
        integer, parameter :: i = 25
        integer, dimension(25, i) :: j = 30
    end module g_mod
    ''')
    code = """subroutine test()
        use g_mod, only: j
        use f_mod, only: some_var
        j = some_var
    end subroutine test"""
    # We need to resolve the module in the Frontend to avoid some_var
    # becoming a call.
    psyir = FortranReader(resolve_modules=True).psyir_from_source(code)
    assign = psyir.walk(Assignment)[0]
    dtrans.apply(assign.rhs)
    out = fortran_writer(psyir)
    assert """  use g_mod, only : i, j
  use f_mod, only : some_var
  integer, dimension(25,i) :: tmp

  tmp = some_var
  j = tmp""" in out