diff --git a/src/ecooptimizer/refactorers/concrete/long_parameter_list.py b/src/ecooptimizer/refactorers/concrete/long_parameter_list.py index 4b1205d8..063ee3de 100644 --- a/src/ecooptimizer/refactorers/concrete/long_parameter_list.py +++ b/src/ecooptimizer/refactorers/concrete/long_parameter_list.py @@ -260,53 +260,195 @@ def update_parameter_usages( ): """ Updates the function body to use encapsulated parameter objects. + This method transforms parameter references in the function body to use new data_params + and config_params objects. + + Args: + function_node: CST node of the function to transform + classified_params: Dictionary mapping parameter groups ('data_params' or 'config_params') + to lists of parameter names in each group + + Returns: + The transformed function node with updated parameter usages """ + # Create a module with just the function to get metadata + module = cst.Module(body=[function_node]) + wrapper = MetadataWrapper(module) class ParameterUsageTransformer(cst.CSTTransformer): - def __init__(self, classified_params: dict[str, list[str]]): - self.param_to_group = {} + """ + A CST transformer that updates parameter references to use the new parameter objects. + """ + METADATA_DEPENDENCIES = (ParentNodeProvider,) + + def __init__( + self, classified_params: dict[str, list[str]], metadata_wrapper: MetadataWrapper + ): + super().__init__() + # map each parameter to its group (data_params or config_params) + self.param_to_group = {} + self.parent_provider = metadata_wrapper.resolve(ParentNodeProvider) # flatten classified_params to map each param to its group (dataParams or configParams) for group, params in classified_params.items(): for param in params: self.param_to_group[param] = group - def leave_Assign( - self, - original_node: cst.Assign, # noqa: ARG002 - updated_node: cst.Assign, - ) -> cst.Assign: + def is_in_assignment_target(self, node: cst.CSTNode) -> bool: """ - Transform only right-hand side references to parameters that need to be updated. - Ensure left-hand side (self attributes) remain unchanged. + Check if a node is part of an assignment target (left side of =). + + Args: + node: The CST node to check + + Returns: + True if the node is part of an assignment target that should not be transformed, + False otherwise + """ + current = node + while current: + parent = self.parent_provider.get(current) + + # if we're at an AssignTarget, check if it's a simple Name assignment + if isinstance(parent, cst.AssignTarget): + if isinstance(current, cst.Name): + # allow transformation for simple parameter assignments + return False + return True + + if isinstance(parent, cst.Assign): + # if we reach an Assign node, check if we came from the targets + for target in parent.targets: + if target.target.deep_equals(current): + if isinstance(current, cst.Name): + # allow transformation for simple parameter assignments + return False + return True + return False + + if isinstance(parent, cst.Module): + return False + + current = parent + return False + + def leave_Name( + self, original_node: cst.Name, updated_node: cst.Name + ) -> cst.BaseExpression: """ - if not isinstance(updated_node.value, cst.Name): + Transform standalone parameter references. + + Skip transformation if: + 1. The name is part of an attribute access (eg: self.param) + 2. The name is part of a complex assignment target (eg: self.x = y) + + Transform if: + 1. The name is a simple parameter being assigned (eg: param1 = value) + 2. The name is used as a value (eg: x = param1) + + Args: + original_node: The original Name node + updated_node: The current state of the Name node + + Returns: + The transformed node or the original if no transformation is needed + """ + # dont't transform if this is part of a complex assignment target + if self.is_in_assignment_target(original_node): return updated_node - var_name = updated_node.value.value + # dont't transform if this is part of an attribute access (e.g., self.param) + parent = self.parent_provider.get(original_node) + if isinstance(parent, cst.Attribute) and original_node is parent.attr: + return updated_node - if var_name in self.param_to_group: - new_value = cst.Attribute( - value=cst.Name(self.param_to_group[var_name]), attr=cst.Name(var_name) + name_value = updated_node.value + if name_value in self.param_to_group: + # transform the name into an attribute access on the appropriate parameter object + return cst.Attribute( + value=cst.Name(self.param_to_group[name_value]), attr=cst.Name(name_value) + ) + return updated_node + + def leave_Attribute( + self, original_node: cst.Attribute, updated_node: cst.Attribute + ) -> cst.BaseExpression: + """ + Handle method calls and attribute access on parameters. + This method handles several cases: + + 1. Assignment targets (eg: self.x = y) + 2. Simple attribute access (eg: self.x or report.x) + 3. Nested attribute access (eg: data_params.user_id) + 4. Subscript access (eg: self.settings["timezone"]) + 5. Parameter attribute access (eg: username.strip()) + + Args: + original_node: The original Attribute node + updated_node: The current state of the Attribute node + + Returns: + The transformed node or the original if no transformation is needed + """ + # don't transform if this is part of an assignment target + if self.is_in_assignment_target(original_node): + # if this is a simple attribute access (eg: self.x or report.x), don't transform it + if isinstance(updated_node.value, cst.Name) and updated_node.value.value in { + "self", + "report", + }: + return original_node + return updated_node + + # if this is a nested attribute access (eg: data_params.user_id), don't transform it further + if ( + isinstance(updated_node.value, cst.Attribute) + and isinstance(updated_node.value.value, cst.Name) + and updated_node.value.value.value in {"data_params", "config_params"} + ): + return updated_node + + # if this is a simple attribute access (eg: self.x or report.x), don't transform it + if isinstance(updated_node.value, cst.Name) and updated_node.value.value in { + "self", + "report", + }: + # check if this is part of a subscript target (eg: self.settings["timezone"]) + parent = self.parent_provider.get(original_node) + if isinstance(parent, cst.Subscript): + return original_node + # check if this is part of a subscript value + if isinstance(parent, cst.SubscriptElement): + return original_node + return original_node + + # if the attribute's value is a parameter name, update it to use the encapsulated parameter object + if ( + isinstance(updated_node.value, cst.Name) + and updated_node.value.value in self.param_to_group + ): + param_name = updated_node.value.value + return cst.Attribute( + value=cst.Name(self.param_to_group[param_name]), attr=updated_node.attr ) - return updated_node.with_changes(value=new_value) return updated_node - # wrap CST node in a MetadataWrapper to enable metadata analysis - transformer = ParameterUsageTransformer(classified_params) - return function_node.visit(transformer) + # create transformer with metadata wrapper + transformer = ParameterUsageTransformer(classified_params, wrapper) + # transform the function body + updated_module = module.visit(transformer) + # return the transformed function + return updated_module.body[0] @staticmethod def get_enclosing_class_name( - tree: cst.Module, # noqa: ARG004 init_node: cst.FunctionDef, parent_metadata: Mapping[cst.CSTNode, cst.CSTNode], ) -> Optional[str]: """ Finds the class name enclosing the given __init__ function node. """ - # wrapper = MetadataWrapper(tree) current_node = init_node while current_node in parent_metadata: parent = parent_metadata[current_node] @@ -324,15 +466,7 @@ def update_function_calls( classified_param_names: tuple[str, str], enclosing_class_name: str, ) -> cst.Module: - """ - Updates all calls to a given function in the provided CST tree to reflect new encapsulated parameters - :param tree: CST tree of the code. - :param function_node: CST node of the function to update calls for. - :param params: A dictionary containing 'data' and 'config' parameters. - :return: The updated CST tree - """ param_to_group = {} - for group_name, params in zip(classified_param_names, classified_params.values()): for param in params: param_to_group[param] = group_name @@ -341,6 +475,15 @@ def update_function_calls( if function_name == "__init__": function_name = enclosing_class_name + # Get all parameter names from the function definition + all_param_names = [p.name.value for p in function_node.params.params] + # Find where variadic args start (if any) + variadic_start = len(all_param_names) + for i, param in enumerate(function_node.params.params): + if param.star == "*" or param.star == "**": + variadic_start = i + break + class FunctionCallTransformer(cst.CSTTransformer): def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002 """Transforms function calls to use grouped parameters.""" @@ -361,13 +504,27 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal positional_args = [] keyword_args = {} - - # Separate positional and keyword arguments - for arg in updated_node.args: - if arg.keyword is None: - positional_args.append(arg.value) - else: - keyword_args[arg.keyword.value] = arg.value + variadic_args = [] + variadic_kwargs = {} + + # Separate positional, keyword, and variadic arguments + for i, arg in enumerate(updated_node.args): + if isinstance(arg, cst.Arg): + if arg.keyword is None: + # If this is a positional argument beyond the number of parameters, + # it's a variadic arg + if i >= variadic_start: + variadic_args.append(arg.value) + elif i < len(used_params): + positional_args.append(arg.value) + else: + # If this is a keyword argument for a used parameter, keep it + if arg.keyword.value in param_to_group: + keyword_args[arg.keyword.value] = arg.value + # If this is a keyword argument not in the original parameters, + # it's a variadic kwarg + elif arg.keyword.value not in all_param_names: + variadic_kwargs[arg.keyword.value] = arg.value # Group arguments based on classified_params grouped_args = {group: [] for group in classified_param_names} @@ -397,6 +554,94 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal if grouped_args[group_name] # Skip empty groups ] + # Add variadic positional arguments + new_args.extend([cst.Arg(value=arg) for arg in variadic_args]) + + # Add variadic keyword arguments + new_args.extend( + [ + cst.Arg(keyword=cst.Name(key), value=value) + for key, value in variadic_kwargs.items() + ] + ) + + return updated_node.with_changes(args=new_args) + + transformer = FunctionCallTransformer() + return tree.visit(transformer) + + @staticmethod + def update_function_calls_unclassified( + tree: cst.Module, + function_node: cst.FunctionDef, + used_params: list[str], + enclosing_class_name: str, + ) -> cst.Module: + """ + Updates all calls to a given function to only include used parameters. + This is used when parameters are removed without being classified into objects. + + Args: + tree: CST tree of the code + function_node: CST node of the function to update calls for + used_params: List of parameter names that are actually used in the function + enclosing_class_name: Name of the enclosing class if this is a method + + Returns: + Updated CST tree with modified function calls + """ + function_name = function_node.name.value + if function_name == "__init__": + function_name = enclosing_class_name + + class FunctionCallTransformer(cst.CSTTransformer): + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002 + """Transforms function calls to only include used parameters.""" + # handle both standalone function calls and instance method calls + if not isinstance(updated_node.func, (cst.Name, cst.Attribute)): + return updated_node + + # extract the function/method name + func_name = ( + updated_node.func.attr.value + if isinstance(updated_node.func, cst.Attribute) + else updated_node.func.value + ) + + # if not the target function, leave unchanged + if func_name != function_name: + return updated_node + + # map original parameters to their positions + param_positions = { + param.name.value: i for i, param in enumerate(function_node.params.params) + } + + # keep track of which positions in the argument list correspond to used parameters + used_positions = {i for param, i in param_positions.items() if param in used_params} + + new_args = [] + pos_arg_count = 0 + + # process all arguments + for arg in updated_node.args: + if arg.keyword is None: + # handle positional arguments + if pos_arg_count in used_positions: + new_args.append(arg) + pos_arg_count += 1 + else: + # handle keyword arguments + if arg.keyword.value in used_params: + # keep keyword arguments for used parameters + new_args.append(arg) + + # ensure the last argument does not have a trailing comma + if new_args: + final_args = new_args[:-1] + final_args.append(new_args[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)) + new_args = final_args + return updated_node.with_changes(args=new_args) transformer = FunctionCallTransformer() @@ -499,7 +744,7 @@ def refactor( self.is_constructor = self.function_node.name.value == "__init__" if self.is_constructor: self.enclosing_class_name = FunctionCallUpdater.get_enclosing_class_name( - tree, self.function_node, parent_metadata + self.function_node, parent_metadata ) param_names = [ param.name.value @@ -562,6 +807,11 @@ def refactor( self.function_node, self.used_params, default_value_params ) + # update all calls to match the new signature + tree = self.function_updater.update_function_calls_unclassified( + tree, self.function_node, self.used_params, self.enclosing_class_name + ) + class FunctionReplacer(cst.CSTTransformer): def __init__( self, original_function: cst.FunctionDef, updated_function: cst.FunctionDef diff --git a/tests/refactorers/test_long_parameter_list_refactor.py b/tests/refactorers/test_long_parameter_list_refactor.py index ad26dcea..4104283d 100644 --- a/tests/refactorers/test_long_parameter_list_refactor.py +++ b/tests/refactorers/test_long_parameter_list_refactor.py @@ -144,11 +144,6 @@ def __init__(self, data_params, config_params): refactorer.refactor(test_file, test_dir, smell, test_file) modified_code = test_file.read_text() - print("***************************************") - print(modified_code.strip()) - print("***************************************") - print(expected_modified_code.strip()) - print("***************************************") assert modified_code.strip() == expected_modified_code.strip() # cleanup after test @@ -306,11 +301,6 @@ def generate_report_partial(data_params, config_params): refactorer.refactor(test_file, test_dir, smell, test_file) modified_code = test_file.read_text() - print("***************************************") - print(modified_code.strip()) - print("***************************************") - print(expected_modified_code.strip()) - print("***************************************") assert modified_code.strip() == expected_modified_code.strip() # cleanup after test @@ -378,3 +368,690 @@ def create_partial_report(data_params, config_params): # cleanup after test test_file.unlink() test_dir.rmdir() + + +def test_lpl_most_unused_params(refactorer, source_files): + """Test for function with 8 params that has 5 parameters unused, refactoring should only remove unused parameters""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + def create_partial_report(user_id, username, email, preferences, timezone_config, language, notification_settings, active_status=None): + report = {} + report.user_id = user_id + report.username = username + + create_partial_report(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", notification_settings=False) + """) + + expected_modified_code = textwrap.dedent("""\ + def create_partial_report(user_id, username): + report = {} + report.user_id = user_id + report.username = username + + create_partial_report(2, "janedoe") + """) + test_file.write_text(code) + smell = create_smell([1])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + +def test_lpl_method_operations(refactorer, source_files): + """Test for function with 8 params that performs operations on parameters""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + def process_user_data(username, email, age, address, phone, preferences, settings, notifications): + \"\"\"Process user data and return a formatted result.\"\"\" + # Process the data + full_name = username.strip() + contact_email = email.lower() + user_age = age + 1 + formatted_address = address.replace(',', '') + clean_phone = phone.replace('-', '') + user_prefs = preferences.copy() + user_settings = settings.copy() + notif_list = notifications.copy() + return { + 'name': full_name, + 'email': contact_email, + 'age': user_age, + 'address': formatted_address, + 'phone': clean_phone, + 'preferences': user_prefs, + 'settings': user_settings, + 'notifications': notif_list + } + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams_process_user_data_1: + def __init__(self, username, email, age, address, phone, preferences, notifications): + self.username = username + self.email = email + self.age = age + self.address = address + self.phone = phone + self.preferences = preferences + self.notifications = notifications + class ConfigParams_process_user_data_1: + def __init__(self, settings): + self.settings = settings + def process_user_data(data_params, config_params): + \"\"\"Process user data and return a formatted result.\"\"\" + # Process the data + full_name = data_params.username.strip() + contact_email = data_params.email.lower() + user_age = data_params.age + 1 + formatted_address = data_params.address.replace(',', '') + clean_phone = data_params.phone.replace('-', '') + user_prefs = data_params.preferences.copy() + user_settings = config_params.settings.copy() + notif_list = data_params.notifications.copy() + return { + 'name': full_name, + 'email': contact_email, + 'age': user_age, + 'address': formatted_address, + 'phone': clean_phone, + 'preferences': user_prefs, + 'settings': user_settings, + 'notifications': notif_list + } + """) + test_file.write_text(code) + smell = create_smell([1])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_parameter_assignments(refactorer, source_files): + """Test for handling parameter assignments and transformations in various contexts""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + class DataProcessor: + def process_data(self, input_data, output_format, config_path, temp_path, cache_path, log_path, backup_path, format_options): + # Simple parameter assignment + backup_path = "/new/backup/path" + + # Parameter used in computation + cache_path = temp_path + "/cache" + + # Parameter assigned to attribute + self.config = config_path + + # Parameter used in method call + output_format = output_format.strip() + + # Parameter used in dictionary + paths = { + "input": input_data, + "output": output_format, + "config": config_path, + "temp": temp_path, + "cache": cache_path, + "log": log_path, + "backup": backup_path + } + + # Parameter used in list + all_paths = [ + input_data, + output_format, + config_path, + temp_path, + cache_path, + log_path, + backup_path + ] + + # Use format options + formatted = format_options.get("style", "default") + + return paths, all_paths, formatted + + processor = DataProcessor() + result = processor.process_data( + "/input", + "json", + "/config", + "/temp", + "/cache", + "/logs", + "/backup", + {"style": "pretty"} + ) + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams_process_data_2: + def __init__(self, input_data, output_format): + self.input_data = input_data + self.output_format = output_format + class ConfigParams_process_data_2: + def __init__(self, config_path, temp_path, cache_path, log_path, backup_path, format_options): + self.config_path = config_path + self.temp_path = temp_path + self.cache_path = cache_path + self.log_path = log_path + self.backup_path = backup_path + self.format_options = format_options + class DataProcessor: + def process_data(self, data_params, config_params): + # Simple parameter assignment + config_params.backup_path = "/new/backup/path" + + # Parameter used in computation + config_params.cache_path = config_params.temp_path + "/cache" + + # Parameter assigned to attribute + self.config = config_params.config_path + + # Parameter used in method call + data_params.output_format = data_params.output_format.strip() + + # Parameter used in dictionary + paths = { + "input": data_params.input_data, + "output": data_params.output_format, + "config": config_params.config_path, + "temp": config_params.temp_path, + "cache": config_params.cache_path, + "log": config_params.log_path, + "backup": config_params.backup_path + } + + # Parameter used in list + all_paths = [ + data_params.input_data, + data_params.output_format, + config_params.config_path, + config_params.temp_path, + config_params.cache_path, + config_params.log_path, + config_params.backup_path + ] + + # Use format options + formatted = config_params.format_options.get("style", "default") + + return paths, all_paths, formatted + + processor = DataProcessor() + result = processor.process_data( + DataParams_process_data_2("/input", "json"), ConfigParams_process_data_2("/config", "/temp", "/cache", "/logs", "/backup", {"style": "pretty"})) + """) + test_file.write_text(code) + smell = create_smell([2])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_with_args_kwargs(refactorer, source_files): + """Test for function with *args and **kwargs""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + def process_data(user_id, username, email, preferences, timezone_config, language, notification_settings, *args, **kwargs): + report = {} + # Use all regular parameters + report.user_id = user_id + report.username = username + report.email = email + report.preferences = preferences + report.timezone = timezone_config + report.language = language + report.notifications = notification_settings + + # Use *args + for arg in args: + report.setdefault("extra_data", []).append(arg) + + # Use **kwargs + for key, value in kwargs.items(): + report[key] = value + + return report + + # Test call with various argument types + result = process_data( + 2, + "janedoe", + "janedoe@example.com", + {"theme": "light"}, + "PST", + "en", + False, + "extra1", + "extra2", + custom_field="custom_value", + another_field=123 + ) + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams_process_data_1: + def __init__(self, user_id, username, email, preferences, language): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.language = language + class ConfigParams_process_data_1: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + def process_data(data_params, config_params, *args, **kwargs): + report = {} + # Use all regular parameters + report.user_id = data_params.user_id + report.username = data_params.username + report.email = data_params.email + report.preferences = data_params.preferences + report.timezone = config_params.timezone_config + report.language = data_params.language + report.notifications = config_params.notification_settings + + # Use *args + for arg in args: + report.setdefault("extra_data", []).append(arg) + + # Use **kwargs + for key, value in kwargs.items(): + report[key] = value + + return report + + # Test call with various argument types + result = process_data( + DataParams_process_data_1(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "en"), ConfigParams_process_data_1("PST", False), "extra1", "extra2", custom_field = "custom_value", another_field = 123)""") + test_file.write_text(code) + smell = create_smell([1])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_with_kwargs_only(refactorer, source_files): + """Test for function with **kwargs""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + def process_data_2(user_id, username, email, preferences, timezone_config, language, notification_settings, **kwargs): + report = {} + # Use all regular parameters + report.user_id = user_id + report.username = username + report.email = email + report.preferences.update(preferences) + report.timezone = timezone_config + report.language = language + report.notifications = notification_settings + + # Use **kwargs + for key, value in kwargs.items(): + report[key] = value # kwargs used + + # Additional processing using the parameters + if notification_settings: + report.timezone = f"{timezone_config}_notified" + + if "theme" in preferences: + report.language = f"{language}_{preferences['theme']}" + + return report + + # Test call with various argument types + result = process_data_2( + 2, + "janedoe", + "janedoe@example.com", + {"theme": "light"}, + "PST", + "en", + False, + custom_field="custom_value", + another_field=123 + ) + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams_process_data_2_1: + def __init__(self, user_id, username, email, preferences, language): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.language = language + class ConfigParams_process_data_2_1: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + def process_data_2(data_params, config_params, **kwargs): + report = {} + # Use all regular parameters + report.user_id = data_params.user_id + report.username = data_params.username + report.email = data_params.email + report.preferences.update(data_params.preferences) + report.timezone = config_params.timezone_config + report.language = data_params.language + report.notifications = config_params.notification_settings + + # Use **kwargs + for key, value in kwargs.items(): + report[key] = value # kwargs used + + # Additional processing using the parameters + if config_params.notification_settings: + report.timezone = f"{config_params.timezone_config}_notified" + + if "theme" in data_params.preferences: + report.language = f"{data_params.language}_{data_params.preferences['theme']}" + + return report + + # Test call with various argument types + result = process_data_2( + DataParams_process_data_2_1(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "en"), ConfigParams_process_data_2_1("PST", False), custom_field = "custom_value", another_field = 123)""") + test_file.write_text(code) + smell = create_smell([1])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_complex_attribute_access(refactorer, source_files): + """Test for complex attribute access and nested parameter usage""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + class DataProcessor: + def process_complex_data(self, user_data, setup_data, cache_data, log_data, temp_data, backup_data, format_data, extra_data): + # Complex attribute access and assignments + self.settings = { + "user": user_data, + "config": setup_data.settings, + "cache": cache_data.path, + "logs": log_data.directory, + "temp": temp_data.location, + "backup": backup_data.storage, + "format": format_data.options, + "extra": extra_data.metadata + } + + # Nested attribute access + if setup_data.settings["enabled"]: + user_data.preferences["theme"] = format_data.options["theme"] + + # Complex assignments + backup_data.storage["path"] = temp_data.location + "/backup" + cache_data.path = temp_data.location + "/cache" + + # Method calls on parameters + cleaned_user = user_data.name.strip().lower() + formatted_config = setup_data.format() + + # Dictionary comprehension using parameters + result = { + key: value.strip() + for key, value in user_data.metadata.items() + if key in setup_data.allowed_keys + } + + return result + + processor = DataProcessor() + result = processor.process_complex_data( + user_data={"name": " John ", "metadata": {"id": "123 ", "role": " admin "}}, + setup_data={"settings": {"enabled": True}, "allowed_keys": ["id"]}, + cache_data={"path": "/tmp/cache"}, + log_data={"directory": "/var/log"}, + temp_data={"location": "/tmp"}, + backup_data={"storage": {}}, + format_data={"options": {"theme": "dark"}}, + extra_data={"metadata": {}} + ) + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams_process_complex_data_2: + def __init__(self, user_data, setup_data, cache_data, log_data, temp_data, backup_data, format_data, extra_data): + self.user_data = user_data + self.setup_data = setup_data + self.cache_data = cache_data + self.log_data = log_data + self.temp_data = temp_data + self.backup_data = backup_data + self.format_data = format_data + self.extra_data = extra_data + class DataProcessor: + def process_complex_data(self, data_params): + # Complex attribute access and assignments + self.settings = { + "user": data_params.user_data, + "config": data_params.setup_data.settings, + "cache": data_params.cache_data.path, + "logs": data_params.log_data.directory, + "temp": data_params.temp_data.location, + "backup": data_params.backup_data.storage, + "format": data_params.format_data.options, + "extra": data_params.extra_data.metadata + } + + # Nested attribute access + if data_params.setup_data.settings["enabled"]: + data_params.user_data.preferences["theme"] = data_params.format_data.options["theme"] + + # Complex assignments + data_params.backup_data.storage["path"] = data_params.temp_data.location + "/backup" + data_params.cache_data.path = data_params.temp_data.location + "/cache" + + # Method calls on parameters + cleaned_user = data_params.user_data.name.strip().lower() + formatted_config = data_params.setup_data.format() + + # Dictionary comprehension using parameters + result = { + key: value.strip() + for key, value in data_params.user_data.metadata.items() + if key in data_params.setup_data.allowed_keys + } + + return result + + processor = DataProcessor() + result = processor.process_complex_data( + DataParams_process_complex_data_2(user_data = {"name": " John ", "metadata": {"id": "123 ", "role": " admin "}}, setup_data = {"settings": {"enabled": True}, "allowed_keys": ["id"]}, cache_data = {"path": "/tmp/cache"}, log_data = {"directory": "/var/log"}, temp_data = {"location": "/tmp"}, backup_data = {"storage": {}}, format_data = {"options": {"theme": "dark"}}, extra_data = {"metadata": {}})) + """) + test_file.write_text(code) + smell = create_smell([2])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_multi_file_refactor(refactorer, source_files): + """Test refactoring a function that is called from another file""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(exist_ok=True) + + # Create the main file with function definition + main_file = test_dir / "main.py" + main_code = textwrap.dedent("""\ + def process_user_data(user_id, username, email, preferences, timezone_config, language, notification_settings, theme): + result = { + 'id': user_id, + 'name': username, + 'email': email, + 'prefs': preferences, + 'tz': timezone_config, + 'lang': language, + 'notif': notification_settings, + 'theme': theme + } + return result + """) + main_file.write_text(main_code) + + # Create another file that uses this function + user_file = test_dir / "user_processor.py" + user_code = textwrap.dedent("""\ + from main import process_user_data + + def handle_user(): + # Call with positional args + result1 = process_user_data( + 1, + "john", + "john@example.com", + {"theme": "light"}, + "PST", + "en", + False, + "dark" + ) + + # Call with keyword args + result2 = process_user_data( + user_id=2, + username="jane", + email="jane@example.com", + preferences={"theme": "dark"}, + timezone_config="UTC", + language="fr", + notification_settings=True, + theme="light" + ) + + return result1, result2 + """) + user_file.write_text(user_code) + + # Expected output for main.py + expected_main_code = textwrap.dedent("""\ + class DataParams_process_user_data_1: + def __init__(self, user_id, username, email, preferences, language, theme): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.language = language + self.theme = theme + class ConfigParams_process_user_data_1: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + def process_user_data(data_params, config_params): + result = { + 'id': data_params.user_id, + 'name': data_params.username, + 'email': data_params.email, + 'prefs': data_params.preferences, + 'tz': config_params.timezone_config, + 'lang': data_params.language, + 'notif': config_params.notification_settings, + 'theme': data_params.theme + } + return result + """) + + # Expected output for user_processor.py + expected_user_code = textwrap.dedent("""\ + from main import process_user_data + class DataParams_process_user_data_1: + def __init__(self, user_id, username, email, preferences, language, theme): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.language = language + self.theme = theme + class ConfigParams_process_user_data_1: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + + def handle_user(): + # Call with positional args + result1 = process_user_data( + DataParams_process_user_data_1(1, "john", "john@example.com", {"theme": "light"}, "en", "dark"), ConfigParams_process_user_data_1("PST", False)) + + # Call with keyword args + result2 = process_user_data( + DataParams_process_user_data_1(user_id = 2, username = "jane", email = "jane@example.com", preferences = {"theme": "dark"}, language = "fr", theme = "light"), ConfigParams_process_user_data_1(timezone_config = "UTC", notification_settings = True)) + + return result1, result2 + """) + + # Apply the refactoring + smell = create_smell([1])() + refactorer.refactor(main_file, test_dir, smell, main_file) + + # Verify both files were updated correctly + modified_main_code = main_file.read_text() + modified_user_code = user_file.read_text() + + assert modified_main_code.strip() == expected_main_code.strip() + assert modified_user_code.strip() == expected_user_code.strip() + + # cleanup after test + main_file.unlink() + user_file.unlink() + test_dir.rmdir()