Skip to content
322 changes: 286 additions & 36 deletions src/ecooptimizer/refactorers/concrete/long_parameter_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,53 +260,195 @@ def update_parameter_usages(
):
"""
Updates the function body to use encapsulated parameter objects.
This method transforms parameter references in the function body to use new data_params
and config_params objects.

Args:
function_node: CST node of the function to transform
classified_params: Dictionary mapping parameter groups ('data_params' or 'config_params')
to lists of parameter names in each group

Returns:
The transformed function node with updated parameter usages
"""
# Create a module with just the function to get metadata
module = cst.Module(body=[function_node])
wrapper = MetadataWrapper(module)

class ParameterUsageTransformer(cst.CSTTransformer):
def __init__(self, classified_params: dict[str, list[str]]):
self.param_to_group = {}
"""
A CST transformer that updates parameter references to use the new parameter objects.
"""

METADATA_DEPENDENCIES = (ParentNodeProvider,)

def __init__(
self, classified_params: dict[str, list[str]], metadata_wrapper: MetadataWrapper
):
super().__init__()
# map each parameter to its group (data_params or config_params)
self.param_to_group = {}
self.parent_provider = metadata_wrapper.resolve(ParentNodeProvider)
# flatten classified_params to map each param to its group (dataParams or configParams)
for group, params in classified_params.items():
for param in params:
self.param_to_group[param] = group

def leave_Assign(
self,
original_node: cst.Assign, # noqa: ARG002
updated_node: cst.Assign,
) -> cst.Assign:
def is_in_assignment_target(self, node: cst.CSTNode) -> bool:
"""
Transform only right-hand side references to parameters that need to be updated.
Ensure left-hand side (self attributes) remain unchanged.
Check if a node is part of an assignment target (left side of =).

Args:
node: The CST node to check

Returns:
True if the node is part of an assignment target that should not be transformed,
False otherwise
"""
current = node
while current:
parent = self.parent_provider.get(current)

# if we're at an AssignTarget, check if it's a simple Name assignment
if isinstance(parent, cst.AssignTarget):
if isinstance(current, cst.Name):
# allow transformation for simple parameter assignments
return False
return True

if isinstance(parent, cst.Assign):
# if we reach an Assign node, check if we came from the targets
for target in parent.targets:
if target.target.deep_equals(current):
if isinstance(current, cst.Name):
# allow transformation for simple parameter assignments
return False
return True
return False

if isinstance(parent, cst.Module):
return False

current = parent
return False

def leave_Name(
self, original_node: cst.Name, updated_node: cst.Name
) -> cst.BaseExpression:
"""
if not isinstance(updated_node.value, cst.Name):
Transform standalone parameter references.

Skip transformation if:
1. The name is part of an attribute access (eg: self.param)
2. The name is part of a complex assignment target (eg: self.x = y)

Transform if:
1. The name is a simple parameter being assigned (eg: param1 = value)
2. The name is used as a value (eg: x = param1)

Args:
original_node: The original Name node
updated_node: The current state of the Name node

Returns:
The transformed node or the original if no transformation is needed
"""
# dont't transform if this is part of a complex assignment target
if self.is_in_assignment_target(original_node):
return updated_node

var_name = updated_node.value.value
# dont't transform if this is part of an attribute access (e.g., self.param)
parent = self.parent_provider.get(original_node)
if isinstance(parent, cst.Attribute) and original_node is parent.attr:
return updated_node

if var_name in self.param_to_group:
new_value = cst.Attribute(
value=cst.Name(self.param_to_group[var_name]), attr=cst.Name(var_name)
name_value = updated_node.value
if name_value in self.param_to_group:
# transform the name into an attribute access on the appropriate parameter object
return cst.Attribute(
value=cst.Name(self.param_to_group[name_value]), attr=cst.Name(name_value)
)
return updated_node

def leave_Attribute(
self, original_node: cst.Attribute, updated_node: cst.Attribute
) -> cst.BaseExpression:
"""
Handle method calls and attribute access on parameters.
This method handles several cases:

1. Assignment targets (eg: self.x = y)
2. Simple attribute access (eg: self.x or report.x)
3. Nested attribute access (eg: data_params.user_id)
4. Subscript access (eg: self.settings["timezone"])
5. Parameter attribute access (eg: username.strip())

Args:
original_node: The original Attribute node
updated_node: The current state of the Attribute node

Returns:
The transformed node or the original if no transformation is needed
"""
# don't transform if this is part of an assignment target
if self.is_in_assignment_target(original_node):
# if this is a simple attribute access (eg: self.x or report.x), don't transform it
if isinstance(updated_node.value, cst.Name) and updated_node.value.value in {
"self",
"report",
}:
return original_node
return updated_node

# if this is a nested attribute access (eg: data_params.user_id), don't transform it further
if (
isinstance(updated_node.value, cst.Attribute)
and isinstance(updated_node.value.value, cst.Name)
and updated_node.value.value.value in {"data_params", "config_params"}
):
return updated_node

# if this is a simple attribute access (eg: self.x or report.x), don't transform it
if isinstance(updated_node.value, cst.Name) and updated_node.value.value in {
"self",
"report",
}:
# check if this is part of a subscript target (eg: self.settings["timezone"])
parent = self.parent_provider.get(original_node)
if isinstance(parent, cst.Subscript):
return original_node
# check if this is part of a subscript value
if isinstance(parent, cst.SubscriptElement):
return original_node
return original_node

# if the attribute's value is a parameter name, update it to use the encapsulated parameter object
if (
isinstance(updated_node.value, cst.Name)
and updated_node.value.value in self.param_to_group
):
param_name = updated_node.value.value
return cst.Attribute(
value=cst.Name(self.param_to_group[param_name]), attr=updated_node.attr
)
return updated_node.with_changes(value=new_value)

return updated_node

# wrap CST node in a MetadataWrapper to enable metadata analysis
transformer = ParameterUsageTransformer(classified_params)
return function_node.visit(transformer)
# create transformer with metadata wrapper
transformer = ParameterUsageTransformer(classified_params, wrapper)
# transform the function body
updated_module = module.visit(transformer)
# return the transformed function
return updated_module.body[0]

@staticmethod
def get_enclosing_class_name(
tree: cst.Module, # noqa: ARG004
init_node: cst.FunctionDef,
parent_metadata: Mapping[cst.CSTNode, cst.CSTNode],
) -> Optional[str]:
"""
Finds the class name enclosing the given __init__ function node.
"""
# wrapper = MetadataWrapper(tree)
current_node = init_node
while current_node in parent_metadata:
parent = parent_metadata[current_node]
Expand All @@ -324,15 +466,7 @@ def update_function_calls(
classified_param_names: tuple[str, str],
enclosing_class_name: str,
) -> cst.Module:
"""
Updates all calls to a given function in the provided CST tree to reflect new encapsulated parameters
:param tree: CST tree of the code.
:param function_node: CST node of the function to update calls for.
:param params: A dictionary containing 'data' and 'config' parameters.
:return: The updated CST tree
"""
param_to_group = {}

for group_name, params in zip(classified_param_names, classified_params.values()):
for param in params:
param_to_group[param] = group_name
Expand All @@ -341,6 +475,15 @@ def update_function_calls(
if function_name == "__init__":
function_name = enclosing_class_name

# Get all parameter names from the function definition
all_param_names = [p.name.value for p in function_node.params.params]
# Find where variadic args start (if any)
variadic_start = len(all_param_names)
for i, param in enumerate(function_node.params.params):
if param.star == "*" or param.star == "**":
variadic_start = i
break

class FunctionCallTransformer(cst.CSTTransformer):
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002
"""Transforms function calls to use grouped parameters."""
Expand All @@ -361,13 +504,27 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal

positional_args = []
keyword_args = {}

# Separate positional and keyword arguments
for arg in updated_node.args:
if arg.keyword is None:
positional_args.append(arg.value)
else:
keyword_args[arg.keyword.value] = arg.value
variadic_args = []
variadic_kwargs = {}

# Separate positional, keyword, and variadic arguments
for i, arg in enumerate(updated_node.args):
if isinstance(arg, cst.Arg):
if arg.keyword is None:
# If this is a positional argument beyond the number of parameters,
# it's a variadic arg
if i >= variadic_start:
variadic_args.append(arg.value)
elif i < len(used_params):
positional_args.append(arg.value)
else:
# If this is a keyword argument for a used parameter, keep it
if arg.keyword.value in param_to_group:
keyword_args[arg.keyword.value] = arg.value
# If this is a keyword argument not in the original parameters,
# it's a variadic kwarg
elif arg.keyword.value not in all_param_names:
variadic_kwargs[arg.keyword.value] = arg.value

# Group arguments based on classified_params
grouped_args = {group: [] for group in classified_param_names}
Expand Down Expand Up @@ -397,6 +554,94 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
if grouped_args[group_name] # Skip empty groups
]

# Add variadic positional arguments
new_args.extend([cst.Arg(value=arg) for arg in variadic_args])

# Add variadic keyword arguments
new_args.extend(
[
cst.Arg(keyword=cst.Name(key), value=value)
for key, value in variadic_kwargs.items()
]
)

return updated_node.with_changes(args=new_args)

transformer = FunctionCallTransformer()
return tree.visit(transformer)

@staticmethod
def update_function_calls_unclassified(
tree: cst.Module,
function_node: cst.FunctionDef,
used_params: list[str],
enclosing_class_name: str,
) -> cst.Module:
"""
Updates all calls to a given function to only include used parameters.
This is used when parameters are removed without being classified into objects.

Args:
tree: CST tree of the code
function_node: CST node of the function to update calls for
used_params: List of parameter names that are actually used in the function
enclosing_class_name: Name of the enclosing class if this is a method

Returns:
Updated CST tree with modified function calls
"""
function_name = function_node.name.value
if function_name == "__init__":
function_name = enclosing_class_name

class FunctionCallTransformer(cst.CSTTransformer):
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002
"""Transforms function calls to only include used parameters."""
# handle both standalone function calls and instance method calls
if not isinstance(updated_node.func, (cst.Name, cst.Attribute)):
return updated_node

# extract the function/method name
func_name = (
updated_node.func.attr.value
if isinstance(updated_node.func, cst.Attribute)
else updated_node.func.value
)

# if not the target function, leave unchanged
if func_name != function_name:
return updated_node

# map original parameters to their positions
param_positions = {
param.name.value: i for i, param in enumerate(function_node.params.params)
}

# keep track of which positions in the argument list correspond to used parameters
used_positions = {i for param, i in param_positions.items() if param in used_params}

new_args = []
pos_arg_count = 0

# process all arguments
for arg in updated_node.args:
if arg.keyword is None:
# handle positional arguments
if pos_arg_count in used_positions:
new_args.append(arg)
pos_arg_count += 1
else:
# handle keyword arguments
if arg.keyword.value in used_params:
# keep keyword arguments for used parameters
new_args.append(arg)

# ensure the last argument does not have a trailing comma
if new_args:
final_args = new_args[:-1]
final_args.append(new_args[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT))
new_args = final_args

return updated_node.with_changes(args=new_args)

transformer = FunctionCallTransformer()
Expand Down Expand Up @@ -499,7 +744,7 @@ def refactor(
self.is_constructor = self.function_node.name.value == "__init__"
if self.is_constructor:
self.enclosing_class_name = FunctionCallUpdater.get_enclosing_class_name(
tree, self.function_node, parent_metadata
self.function_node, parent_metadata
)
param_names = [
param.name.value
Expand Down Expand Up @@ -562,6 +807,11 @@ def refactor(
self.function_node, self.used_params, default_value_params
)

# update all calls to match the new signature
tree = self.function_updater.update_function_calls_unclassified(
tree, self.function_node, self.used_params, self.enclosing_class_name
)

class FunctionReplacer(cst.CSTTransformer):
def __init__(
self, original_function: cst.FunctionDef, updated_function: cst.FunctionDef
Expand Down
Loading
Loading