Skip to content
3 changes: 3 additions & 0 deletions src/psyclone/psyir/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@
from psyclone.psyir.transformations.omp_parallel_trans import (
OMPParallelTrans,
)
from psyclone.psyir.transformations.datanode_extract_trans import (
DataNodeExtractTrans
)

# For AutoAPI documentation generation
__all__ = [
Expand Down
263 changes: 263 additions & 0 deletions src/psyclone/psyir/transformations/datanode_extract_trans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
# -----------------------------------------------------------------------------
# BSD 3-Clause License
#
# Copyright (c) 2026, Science and Technology Facilities Council.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Authors A. B. G. Chalk STFC Daresbury Lab

'''This module contains the DataNodeExtractTrans class.'''

from psyclone.psyGen import Transformation
from psyclone.psyir.transformations import TransformationError
from psyclone.psyir.nodes import (
DataNode,
Reference,
Assignment,
Statement,
Call
)
from psyclone.psyir.symbols.datatypes import (
ArrayType,
UnresolvedType,
UnsupportedFortranType,
)
from psyclone.psyir.symbols import (
DataSymbol, ImportInterface, ContainerSymbol)
from psyclone.utils import transformation_documentation_wrapper


@transformation_documentation_wrapper
class DataNodeExtractTrans(Transformation):
"""Provides a generic transformation for moving a datanode from a
statement into a new standalone statement. For example:

>>> from psyclone.psyir.frontend.fortran import FortranReader
>>> from psyclone.psyir.backend.fortran import FortranWriter
>>> from psyclone.psyir.nodes import Assignment
>>> from psyclone.psyir.transformations import DataNodeExtractTrans
>>>
>>> psyir = FortranReader().psyir_from_source('''
... subroutine my_subroutine()
... integer :: i
... integer :: j
... i = j * 2
... end subroutine
... ''')
>>> assign = psyir.walk(Assignment)[0]
>>> DataNodeExtractTrans().apply(assign.rhs, storage_name="temp")
>>> print(FortranWriter()(psyir))
subroutine my_subroutine()
integer, dimension(10,10) :: a
integer :: i
integer :: j
integer :: temp
<BLANKLINE>
temp = j * 2
i = temp
<BLANKLINE>
end subroutine my_subroutine
<BLANKLINE>
"""

def validate(self, node: DataNode, **kwargs):
"""Validity checks for input arguments

:param node: The DataNode to be extracted.

:raises TypeError: if the input arguments are the wrong types.
:raises TransformationError: if the input node's datatype can't be
resolved.
:raises TransformationError: if the input node's datatype is an array
but any of the array's dimensions are unknown.
:raises TransformationError: if the input node doesn't have an
ancestor statement.
:raises TransformationError: if the input node contains a call
that isn't guaranteed to be pure.
"""
# Validate the input options and types.
self.validate_options(**kwargs)

if not isinstance(node, DataNode):
raise TypeError(
f"Input node to DataNodeExtractTrans should be a "
f"DataNode but got '{type(node).__name__}'."
)

dtype = node.datatype

calls = node.walk(Call)
for call in calls:
if not call.is_pure:
raise TransformationError(
f"Input node to DataNodeExtractTrans contains a call "
f"that is not guaranteed to be pure. Input node is "
f"'{node.debug_string().strip()}'."
)
if isinstance(dtype, ArrayType):
for element in dtype.shape:
if element in [ArrayType.Extent.DEFERRED,
ArrayType.Extent.ATTRIBUTE]:
raise TransformationError(
f"Input node's datatype is an array of unknown size, "
f"so the DataNodeExtractTrans cannot be applied. "
f"Input node was '{node.debug_string().strip()}'."
)
# Otherwise we have an ArrayBounds
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need a bit more help following this code. Perhaps, "...ArrayBounds - check to see which Symbols are referenced in those bounds."

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

symbols = set()
if isinstance(element.lower, DataNode):
symbols.update(element.lower.get_all_accessed_symbols())
if isinstance(element.upper, DataNode):
symbols.update(element.upper.get_all_accessed_symbols())
scope_symbols = node.scope.symbol_table.get_symbols()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Compare these symbols with those already in scope at node"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

for sym in symbols:
scoped_name_sym = scope_symbols.get(sym.name, None)
if scoped_name_sym and sym is not scoped_name_sym:
raise TransformationError(
f"Input node contains an imported symbol whose "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a while to work out what this check was for. Could you extend the text to say that the "The type of the node supplied to {self.name} depends upon an imported symbol ('{sym.name}') which has a name clash with a symbol in the current scope."

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

f"name collides with an existing symbol, so the "
f"DataNodeExtractTrans cannot be applied. "
f"Clashing symbol name is '{sym.name}'."
)
# If its an imported symbol we need to check if its
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want this check as part of the previous one as we may have the same symbol imported from the same location but represented by different Symbol instances. i.e. sym is not scoped_name_sym but, if they are both imported from the same module, they must be representing the same variable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arporter I did wonder about this initially, but couldn't find a way to create a test where it happened, so abandoned it, but I'll have another go since you also think it should be possible.

# the same import interface.
if isinstance(sym.interface, ImportInterface):
scoped_name_sym = scope_symbols.get(
sym.interface.container_symbol.name,
None
)
if scoped_name_sym and not isinstance(
scoped_name_sym, ContainerSymbol):
raise TransformationError(
f"Input node contains an imported symbol "
f"whose containing module collides with an "
f"existing symbol. Colliding name is "
f"'{sym.interface.container_symbol.name}'."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately we also have to allow for a situation when the type of the imported variable (a_var in this case) depends on private variables:

module some_other_mod
    integer, parameter :: dim1 = 4, dim2 = 5
    real(kind=wp), dimension(dim1, dim2) :: a_var
    public :: a_var
    private
end module

in which case we can't import the symbols that define the shape of the array. You'll need to check their visibility in the symbol_table associated with the Container pointed to by the import interface of a_var.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arporter I'm a bit confused when trying to do this test - when I examine the interface of dim1 and dim2 (as found in the size of the a_var.datatype)I don't get an ImportInterface but instead a StaticInterface. Is this expected behaviour? If so, how can I differentiate between two private static variables where one is in the current scope and one is from a "resolve imports" module?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm not sure how to tell the difference between the dim/dim2 symbols in your example, and x in this code:

  implicit none
  integer, parameter, private :: x = 1
  private

  contains
  subroutine sub()
    integer :: i

    i = x

  end subroutine sub

end module test

I could check that if its a static variable then it has to be in the scope? But why are those not ImportInterface symbols despite being static?


if node.ancestor(Statement) is None:
raise TransformationError(
"Input node to DataNodeExtractTrans has no ancestor "
"Statement node which is not supported."
)

if isinstance(dtype, (UnresolvedType, UnsupportedFortranType)):
raise TransformationError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit of a faff but I think it would be very useful to the user if you could examine the symbols referenced in node and see if you can establish which of them has unsupported/unresolved type. It may be that they could be resolved by setting RESOLVE_IMPORTS appropriately in the transformation script.

f"Input node's datatype cannot be computed, so the "
f"DataNodeExtractTrans cannot be applied. Input node was "
f"'{node.debug_string().strip()}'."
)

def apply(self, node: DataNode, storage_name: str = "", **kwargs):
"""Applies the DataNodeExtractTrans to the input arguments.

:param node: The datanode to extract.
:param storage_name: The name of the temporary variable to store
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps "The (base) name of the..." to indicate that it won't necessarily be exactly that name that is used. Also, please could you de-dent the following lines to just use four spaces so we don't need as many lines.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we need to decide if this should be a base name of if the storage_name argument is provided the name must be exact? Happy to go either way - if the name must be exact I should probably add it to validate.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't imagine anyone ever caring that it was exact so I vote for it just being a suggestion (i.e. base_name).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this behaviour now.

the result of the input node in. The default is tmp(_...)
based on the rules defined in the SymbolTable class.
"""
# Call validate to check inputs are valid.
self.validate(node, storage_name=storage_name, **kwargs)

# Find the datatype
datatype = node.datatype

# Create a symbol of the relevant type.
if not storage_name:
symbol = node.scope.symbol_table.new_symbol(
root_name="tmp",
symbol_type=DataSymbol,
datatype=datatype
)
else:
symbol = node.scope.symbol_table.new_symbol(
root_name=storage_name,
symbol_type=DataSymbol,
allow_renaming=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come you've set this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking the user-defined name should be honoured exactly, but I'm happy to relax that if you think that would be better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, I think this can be relaxed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

datatype=datatype
)

# FIXME Make sure the shape is all in the symbol table. We know that
# all symbols we find can be safely added as otherwise validate will
# fail.
# This is an oversimplification because we could have multiple
# references to the same symbol...
if isinstance(datatype, ArrayType):
for element in datatype.shape:
symbols = set()
if isinstance(element.lower, DataNode):
symbols.update(element.lower.get_all_accessed_symbols())
if isinstance(element.upper, DataNode):
symbols.update(element.upper.get_all_accessed_symbols())
scope_symbols = node.scope.symbol_table.get_symbols()
for sym in symbols:
scoped_name_sym = scope_symbols.get(sym.name, None)
# If no symbol with the name exists then create one.
if not scoped_name_sym:
sym_copy = sym.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't understood what you meant by "import symbols" before but now I do. I guess the premise here is that if a symbol appears in the definition of an imported symbol's shape, then it must be a compile-time constant? (Because if it was allocatable, we wouldn't have sizes for its dimensions.) In which case I think it's OK. Presumably also, we're only adding these symbols into the table associated with a Routine? (I started worrying that they might have module scope, in which case we would need to make sure they were private.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth summarising this in a comment I think - possibly instead of the FIXME above?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite, if we have a symbol (that needs to be imported from another module and currently isn't) in the shape then we add it to the symbol table with an ImportInterface, so the output should be use the_module, only: the_shape_symbol. Updating the comment to explain this I think.

if isinstance(sym_copy.interface, ImportInterface):
# Check if the ContainerSymbol is already in the
# interface
container = scope_symbols.get(
sym_copy.interface.container_symbol.name,
None
)
if container is None:
# Add the container symbol the the symbol table
# and we're ok with this symbol.
node.scope.symbol_table.add(
sym_copy.interface.container_symbol
)
# If we find the container then we need to update
# the interface to use the container listed.
else:
sym_copy.interface.container_symbol = \
container
node.scope.symbol_table.add(sym_copy)

# Create a Reference to the new symbol
new_ref = Reference(symbol)

# Find the parent and position of the statement containing the
# DataNode.
parent = node.ancestor(Statement).parent
pos = node.ancestor(Statement).position

# Replace the datanode with the new reference
node.replace_with(new_ref)

# Create an assignment to set the value of the new symbol
assign = Assignment.create(new_ref.copy(), node)

# Add the assignment into the tree.
parent.addchild(assign, pos)


__all__ = ["DataNodeExtractTrans"]
Loading
Loading