-
Notifications
You must be signed in to change notification settings - Fork 33
DataNodeExtractTrans implementation #3301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
1adc9ce
d4d7510
8c99aa3
dd30fe9
f81a2c3
9a28c7b
13db8c7
731bb59
ffa39bc
b3a90fc
5c5fea5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,263 @@ | ||
| # ----------------------------------------------------------------------------- | ||
| # BSD 3-Clause License | ||
| # | ||
| # Copyright (c) 2026, Science and Technology Facilities Council. | ||
| # All rights reserved. | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions are met: | ||
| # | ||
| # * Redistributions of source code must retain the above copyright notice, this | ||
| # list of conditions and the following disclaimer. | ||
| # | ||
| # * Redistributions in binary form must reproduce the above copyright notice, | ||
| # this list of conditions and the following disclaimer in the documentation | ||
| # and/or other materials provided with the distribution. | ||
| # | ||
| # * Neither the name of the copyright holder nor the names of its | ||
| # contributors may be used to endorse or promote products derived from | ||
| # this software without specific prior written permission. | ||
| # | ||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||
| # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||
| # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS | ||
| # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE | ||
| # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, | ||
| # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | ||
| # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
| # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | ||
| # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN | ||
| # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
| # POSSIBILITY OF SUCH DAMAGE. | ||
| # ----------------------------------------------------------------------------- | ||
| # Authors A. B. G. Chalk STFC Daresbury Lab | ||
|
|
||
| '''This module contains the DataNodeExtractTrans class.''' | ||
|
|
||
| from psyclone.psyGen import Transformation | ||
| from psyclone.psyir.transformations import TransformationError | ||
| from psyclone.psyir.nodes import ( | ||
| DataNode, | ||
| Reference, | ||
| Assignment, | ||
| Statement, | ||
| Call | ||
| ) | ||
| from psyclone.psyir.symbols.datatypes import ( | ||
| ArrayType, | ||
| UnresolvedType, | ||
| UnsupportedFortranType, | ||
| ) | ||
| from psyclone.psyir.symbols import ( | ||
| DataSymbol, ImportInterface, ContainerSymbol) | ||
| from psyclone.utils import transformation_documentation_wrapper | ||
|
|
||
|
|
||
| @transformation_documentation_wrapper | ||
| class DataNodeExtractTrans(Transformation): | ||
| """Provides a generic transformation for moving a datanode from a | ||
| statement into a new standalone statement. For example: | ||
|
|
||
| >>> from psyclone.psyir.frontend.fortran import FortranReader | ||
| >>> from psyclone.psyir.backend.fortran import FortranWriter | ||
| >>> from psyclone.psyir.nodes import Assignment | ||
| >>> from psyclone.psyir.transformations import DataNodeExtractTrans | ||
| >>> | ||
| >>> psyir = FortranReader().psyir_from_source(''' | ||
| ... subroutine my_subroutine() | ||
| ... integer :: i | ||
| ... integer :: j | ||
| ... i = j * 2 | ||
| ... end subroutine | ||
| ... ''') | ||
| >>> assign = psyir.walk(Assignment)[0] | ||
| >>> DataNodeExtractTrans().apply(assign.rhs, storage_name="temp") | ||
| >>> print(FortranWriter()(psyir)) | ||
| subroutine my_subroutine() | ||
| integer, dimension(10,10) :: a | ||
| integer :: i | ||
| integer :: j | ||
| integer :: temp | ||
| <BLANKLINE> | ||
| temp = j * 2 | ||
| i = temp | ||
| <BLANKLINE> | ||
| end subroutine my_subroutine | ||
| <BLANKLINE> | ||
| """ | ||
|
|
||
| def validate(self, node: DataNode, **kwargs): | ||
| """Validity checks for input arguments | ||
|
|
||
| :param node: The DataNode to be extracted. | ||
|
|
||
| :raises TypeError: if the input arguments are the wrong types. | ||
| :raises TransformationError: if the input node's datatype can't be | ||
| resolved. | ||
| :raises TransformationError: if the input node's datatype is an array | ||
| but any of the array's dimensions are unknown. | ||
| :raises TransformationError: if the input node doesn't have an | ||
| ancestor statement. | ||
| :raises TransformationError: if the input node contains a call | ||
| that isn't guaranteed to be pure. | ||
| """ | ||
| # Validate the input options and types. | ||
| self.validate_options(**kwargs) | ||
|
|
||
| if not isinstance(node, DataNode): | ||
| raise TypeError( | ||
| f"Input node to DataNodeExtractTrans should be a " | ||
| f"DataNode but got '{type(node).__name__}'." | ||
| ) | ||
|
|
||
| dtype = node.datatype | ||
|
|
||
| calls = node.walk(Call) | ||
| for call in calls: | ||
| if not call.is_pure: | ||
| raise TransformationError( | ||
| f"Input node to DataNodeExtractTrans contains a call " | ||
| f"that is not guaranteed to be pure. Input node is " | ||
| f"'{node.debug_string().strip()}'." | ||
| ) | ||
| if isinstance(dtype, ArrayType): | ||
| for element in dtype.shape: | ||
| if element in [ArrayType.Extent.DEFERRED, | ||
| ArrayType.Extent.ATTRIBUTE]: | ||
| raise TransformationError( | ||
| f"Input node's datatype is an array of unknown size, " | ||
| f"so the DataNodeExtractTrans cannot be applied. " | ||
| f"Input node was '{node.debug_string().strip()}'." | ||
| ) | ||
| # Otherwise we have an ArrayBounds | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need a bit more help following this code. Perhaps, "...ArrayBounds - check to see which Symbols are referenced in those bounds."
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| symbols = set() | ||
| if isinstance(element.lower, DataNode): | ||
| symbols.update(element.lower.get_all_accessed_symbols()) | ||
| if isinstance(element.upper, DataNode): | ||
| symbols.update(element.upper.get_all_accessed_symbols()) | ||
| scope_symbols = node.scope.symbol_table.get_symbols() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Compare these symbols with those already in scope at
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added. |
||
| for sym in symbols: | ||
| scoped_name_sym = scope_symbols.get(sym.name, None) | ||
| if scoped_name_sym and sym is not scoped_name_sym: | ||
| raise TransformationError( | ||
| f"Input node contains an imported symbol whose " | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It took me a while to work out what this check was for. Could you extend the text to say that the "The type of the node supplied to {self.name} depends upon an imported symbol ('{sym.name}') which has a name clash with a symbol in the current scope."
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| f"name collides with an existing symbol, so the " | ||
| f"DataNodeExtractTrans cannot be applied. " | ||
| f"Clashing symbol name is '{sym.name}'." | ||
| ) | ||
| # If its an imported symbol we need to check if its | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we want this check as part of the previous one as we may have the same symbol imported from the same location but represented by different Symbol instances. i.e.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arporter I did wonder about this initially, but couldn't find a way to create a test where it happened, so abandoned it, but I'll have another go since you also think it should be possible. |
||
| # the same import interface. | ||
| if isinstance(sym.interface, ImportInterface): | ||
| scoped_name_sym = scope_symbols.get( | ||
| sym.interface.container_symbol.name, | ||
| None | ||
| ) | ||
| if scoped_name_sym and not isinstance( | ||
| scoped_name_sym, ContainerSymbol): | ||
| raise TransformationError( | ||
| f"Input node contains an imported symbol " | ||
| f"whose containing module collides with an " | ||
| f"existing symbol. Colliding name is " | ||
| f"'{sym.interface.container_symbol.name}'." | ||
| ) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately we also have to allow for a situation when the type of the imported variable ( module some_other_mod
integer, parameter :: dim1 = 4, dim2 = 5
real(kind=wp), dimension(dim1, dim2) :: a_var
public :: a_var
private
end modulein which case we can't import the symbols that define the shape of the array. You'll need to check their visibility in the symbol_table associated with the Container pointed to by the import interface of
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arporter I'm a bit confused when trying to do this test - when I examine the interface of
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I'm not sure how to tell the difference between the dim/dim2 symbols in your example, and x in this code: I could check that if its a static variable then it has to be in the scope? But why are those not ImportInterface symbols despite being static? |
||
|
|
||
| if node.ancestor(Statement) is None: | ||
| raise TransformationError( | ||
| "Input node to DataNodeExtractTrans has no ancestor " | ||
| "Statement node which is not supported." | ||
| ) | ||
|
|
||
| if isinstance(dtype, (UnresolvedType, UnsupportedFortranType)): | ||
| raise TransformationError( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a bit of a faff but I think it would be very useful to the user if you could examine the symbols referenced in |
||
| f"Input node's datatype cannot be computed, so the " | ||
| f"DataNodeExtractTrans cannot be applied. Input node was " | ||
| f"'{node.debug_string().strip()}'." | ||
| ) | ||
|
|
||
arporter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def apply(self, node: DataNode, storage_name: str = "", **kwargs): | ||
| """Applies the DataNodeExtractTrans to the input arguments. | ||
|
|
||
| :param node: The datanode to extract. | ||
| :param storage_name: The name of the temporary variable to store | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps "The (base) name of the..." to indicate that it won't necessarily be exactly that name that is used. Also, please could you de-dent the following lines to just use four spaces so we don't need as many lines.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we need to decide if this should be a base name of if the storage_name argument is provided the name must be exact? Happy to go either way - if the name must be exact I should probably add it to validate.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't imagine anyone ever caring that it was exact so I vote for it just being a suggestion (i.e. base_name).
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed this behaviour now. |
||
| the result of the input node in. The default is tmp(_...) | ||
| based on the rules defined in the SymbolTable class. | ||
| """ | ||
| # Call validate to check inputs are valid. | ||
| self.validate(node, storage_name=storage_name, **kwargs) | ||
|
|
||
| # Find the datatype | ||
| datatype = node.datatype | ||
|
|
||
| # Create a symbol of the relevant type. | ||
| if not storage_name: | ||
| symbol = node.scope.symbol_table.new_symbol( | ||
arporter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| root_name="tmp", | ||
| symbol_type=DataSymbol, | ||
| datatype=datatype | ||
| ) | ||
| else: | ||
| symbol = node.scope.symbol_table.new_symbol( | ||
| root_name=storage_name, | ||
| symbol_type=DataSymbol, | ||
| allow_renaming=False, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How come you've set this?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking the user-defined name should be honoured exactly, but I'm happy to relax that if you think that would be better.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed, I think this can be relaxed.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| datatype=datatype | ||
| ) | ||
|
|
||
| # FIXME Make sure the shape is all in the symbol table. We know that | ||
| # all symbols we find can be safely added as otherwise validate will | ||
| # fail. | ||
| # This is an oversimplification because we could have multiple | ||
| # references to the same symbol... | ||
| if isinstance(datatype, ArrayType): | ||
| for element in datatype.shape: | ||
| symbols = set() | ||
| if isinstance(element.lower, DataNode): | ||
| symbols.update(element.lower.get_all_accessed_symbols()) | ||
| if isinstance(element.upper, DataNode): | ||
| symbols.update(element.upper.get_all_accessed_symbols()) | ||
| scope_symbols = node.scope.symbol_table.get_symbols() | ||
| for sym in symbols: | ||
| scoped_name_sym = scope_symbols.get(sym.name, None) | ||
| # If no symbol with the name exists then create one. | ||
| if not scoped_name_sym: | ||
| sym_copy = sym.copy() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hadn't understood what you meant by "import symbols" before but now I do. I guess the premise here is that if a symbol appears in the definition of an imported symbol's shape, then it must be a compile-time constant? (Because if it was allocatable, we wouldn't have sizes for its dimensions.) In which case I think it's OK. Presumably also, we're only adding these symbols into the table associated with a Routine? (I started worrying that they might have module scope, in which case we would need to make sure they were private.)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Worth summarising this in a comment I think - possibly instead of the FIXME above?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not quite, if we have a symbol (that needs to be imported from another module and currently isn't) in the shape then we add it to the symbol table with an ImportInterface, so the output should be |
||
| if isinstance(sym_copy.interface, ImportInterface): | ||
| # Check if the ContainerSymbol is already in the | ||
| # interface | ||
| container = scope_symbols.get( | ||
| sym_copy.interface.container_symbol.name, | ||
| None | ||
| ) | ||
| if container is None: | ||
| # Add the container symbol the the symbol table | ||
| # and we're ok with this symbol. | ||
| node.scope.symbol_table.add( | ||
| sym_copy.interface.container_symbol | ||
| ) | ||
| # If we find the container then we need to update | ||
| # the interface to use the container listed. | ||
| else: | ||
| sym_copy.interface.container_symbol = \ | ||
| container | ||
| node.scope.symbol_table.add(sym_copy) | ||
|
|
||
| # Create a Reference to the new symbol | ||
| new_ref = Reference(symbol) | ||
|
|
||
| # Find the parent and position of the statement containing the | ||
| # DataNode. | ||
| parent = node.ancestor(Statement).parent | ||
arporter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| pos = node.ancestor(Statement).position | ||
|
|
||
| # Replace the datanode with the new reference | ||
| node.replace_with(new_ref) | ||
|
|
||
| # Create an assignment to set the value of the new symbol | ||
| assign = Assignment.create(new_ref.copy(), node) | ||
|
|
||
| # Add the assignment into the tree. | ||
| parent.addchild(assign, pos) | ||
arporter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| __all__ = ["DataNodeExtractTrans"] | ||
Uh oh!
There was an error while loading. Please reload this page.