diff --git a/.github/workflows/python-lint.yaml b/.github/workflows/python-lint.yaml index 133de123..c9bd3ab9 100644 --- a/.github/workflows/python-lint.yaml +++ b/.github/workflows/python-lint.yaml @@ -26,18 +26,6 @@ jobs: with: token: ${{ steps.app-token.outputs.token }} - # Get changed .py files - - name: Get changed .py files - id: changed-py-files - uses: tj-actions/changed-files@v45 - with: - files: | - **/*.py - files_ignore: | - tests/input/**/*.py - tests/_input_copies/**/*.py - diff_relative: true # Get the list of files relative to the repo root - - name: Install Python uses: actions/setup-python@v5 with: @@ -49,8 +37,6 @@ jobs: pip install ruff - name: Run Ruff - env: - ALL_CHANGED_FILES: ${{ steps.changed-py-files.outputs.all_changed_files }} run: | - ruff check $ALL_CHANGED_FILES --output-format=github . + ruff check --output-format=github diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 45902a32..72533456 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -53,10 +53,3 @@ jobs: run: | git fetch origin ${{ github.base_ref }} diff-cover coverage.xml --compare-branch=origin/${{ github.base_ref }} --fail-under=80 - - - name: Check Per-File Coverage - run: | - for file in ${{ steps.changed-files.outputs.all_changed_files }}; do - echo "Checking overall coverage for $file" - coverage report --include=$file --fail-under=80 || exit 1 - done diff --git a/.gitignore b/.gitignore index 51b86108..3f8602fe 100644 --- a/.gitignore +++ b/.gitignore @@ -286,4 +286,25 @@ TSWLatexianTemp* # DRAW.IO files *.drawio -*.drawio.bkp \ No newline at end of file +*.drawio.bkp + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +.venv/ + +# Rope +.ropeproject + +*.egg-info/ + +# Package files +outputs/ +build/ +tests/temp_dir/ +tests/benchmarking/output/ + +# Coverage +.coverage +coverage.* \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 
index 00000000..17662ddc --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "plugin/capstone--sco-vs-code-plugin"] + path = plugin/capstone--sco-vs-code-plugin + url = https://github.com/ssm-lab/capstone--sco-vs-code-plugin.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..2ad9d923 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.7.4 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format + \ No newline at end of file diff --git a/docs/projMngmnt/Rev0_Team_Contrib.pdf b/docs/projMngmnt/Rev0_Team_Contrib.pdf index 4d8f2f1a..b614dae0 100644 Binary files a/docs/projMngmnt/Rev0_Team_Contrib.pdf and b/docs/projMngmnt/Rev0_Team_Contrib.pdf differ diff --git a/src/analyzers/__init__.py b/plugin/README.md similarity index 100% rename from src/analyzers/__init__.py rename to plugin/README.md diff --git a/plugin/capstone--sco-vs-code-plugin b/plugin/capstone--sco-vs-code-plugin new file mode 160000 index 00000000..55908450 --- /dev/null +++ b/plugin/capstone--sco-vs-code-plugin @@ -0,0 +1 @@ +Subproject commit 55908450f8041d4a4ad041eada803597bf5d0bfc diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..25181b22 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,141 @@ +[build-system] +requires = ["setuptools >= 61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "ecooptimizer" +version = "0.0.1" +dependencies = [ + "pylint", + "rope", + "astor", + "codecarbon", + "asttokens", + "uvicorn", + "fastapi", + "pydantic", + "libcst", + "websockets", +] +requires-python = ">=3.9" +authors = [ + { name = "Sevhena Walker" }, + { name = "Mya Hussain" }, + { name = "Nivetha Kuruparan" }, + { name = "Ayushi Amin" }, + { name = "Tanveer Brar" }, +] + +description = "A source code eco optimizer" +readme = "README.md" +license = { file = 
"LICENSE" } + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "pytest-mock", + "ruff", + "coverage", + "pyright", + "pre-commit", +] + +[project.scripts] +eco-local = "ecooptimizer.__main__:main" +eco-ext = "ecooptimizer.api.__main__:main" +eco-ext-dev = "ecooptimizer.api.__main__:dev" + +[project.urls] +Documentation = "https://readthedocs.org" +Repository = "https://github.com/ssm-lab/capstone--source-code-optimizer" +"Bug Tracker" = "https://github.com/ssm-lab/capstone--source-code-optimizer/issues" + +[tool.pytest.ini_options] +norecursedirs = ["tests/temp*", "tests/input", "tests/_input_copies"] +addopts = ["--basetemp=tests/temp_dir"] +testpaths = ["tests"] +pythonpath = "src" + +[tool.coverage.run] +omit = [ + "*/__main__.py", + '*/__init__.py', + '*/utils/*', + "*/test_*.py", + "*/analyzers/*_analyzer.py", + "*/api/app.py", + "*/api/routes/show_logs.py", +] + +[tool.ruff] +extend-exclude = [ + "*tests/input/**/*.py", + "tests/_input_copies", + "tests/temp_dir", + "tests/benchmarking/test_code/**/*.py", +] +line-length = 100 + +[tool.ruff.lint] +select = [ + "E", # Enforce Python Error rules (e.g., syntax errors, exceptions). + "UP", # Check for unnecessary passes and other unnecessary constructs. + "ANN001", # Ensure type annotations are present where needed. + "ANN002", + "ANN003", + "ANN401", + "INP", # Flag invalid Python patterns or usage. + "PTH", # Check path-like or import-related issues. + "F", # Enforce function-level checks (e.g., complexity, arguments). + "B", # Enforce best practices for Python coding (general style rules). + "PT", # Enforce code formatting and Pythonic idioms. + "W", # Enforce warnings (e.g., suspicious constructs or behaviours). + "A", # Flag common anti-patterns or bad practices. + "RUF", # Ruff-specific rules. + "ARG", # Check for function argument issues., +] + +# Avoid enforcing line-length violations (`E501`) +ignore = ["E501", "RUF003"] + +# Avoid trying to fix flake8-bugbear (`B`) violations. 
+unfixable = ["B"] + +# Ignore `E402` (import violations) in all `__init__.py` files, and in selected subdirectories. +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402"] +"**/{tests,docs,tools}/*" = ["E402", "ANN", "INP001"] + +[tool.ruff.lint.flake8-annotations] +suppress-none-returning = true +mypy-init-return = true + +[tool.pyright] +include = ["src", "tests"] +exclude = ["tests/input", "tests/_input*", "tests/temp_dir"] + +disableBytesTypePromotions = true +reportAttributeAccessIssue = false +reportPropertyTypeMismatch = true +reportFunctionMemberAccess = true +reportMissingImports = true +reportUnusedVariable = "warning" +reportDuplicateImport = "warning" +reportUntypedFunctionDecorator = true +reportUntypedClassDecorator = true +reportUntypedBaseClass = true +reportUntypedNamedTuple = true +reportPrivateUsage = true +reportConstantRedefinition = "warning" +reportDeprecated = "warning" +reportIncompatibleMethodOverride = true +reportIncompatibleVariableOverride = true +reportInconsistentConstructor = true +reportOverlappingOverload = true +reportMissingTypeArgument = true +reportCallInDefaultInitializer = "warning" +reportUnnecessaryIsInstance = "warning" +reportUnnecessaryCast = "warning" +reportUnnecessaryComparison = true +reportMatchNotExhaustive = "warning" diff --git a/src/analyzers/base_analyzer.py b/src/analyzers/base_analyzer.py deleted file mode 100644 index cad46036..00000000 --- a/src/analyzers/base_analyzer.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABC, abstractmethod - -class BaseAnalyzer(ABC): - def __init__(self, code_path: str): - self.code_path = code_path - - @abstractmethod - def analyze(self): - pass diff --git a/src/analyzers/pylint_analyzer.py b/src/analyzers/pylint_analyzer.py deleted file mode 100644 index c8675a50..00000000 --- a/src/analyzers/pylint_analyzer.py +++ /dev/null @@ -1,70 +0,0 @@ -import subprocess -import json -from analyzers.base_analyzer import BaseAnalyzer - -class PylintAnalyzer(BaseAnalyzer): - def 
__init__(self, code_path: str): - super().__init__(code_path) - self.code_smells = { - "R0902": "Large Class", # Too many instance attributes - "R0913": "Long Parameter List", # Too many arguments - "R0915": "Long Method", # Too many statements - "C0200": "Complex List Comprehension", # Loop can be simplified - "C0103": "Invalid Naming Convention", # Non-standard names - # Add other pylint codes as needed - } - - def analyze(self): - """ - Runs Pylint on the specified code path and returns a report of code smells. - """ - pylint_command = [ - "pylint", "--output-format=json", self.code_path - ] - - try: - result = subprocess.run(pylint_command, capture_output=True, text=True, check=True) - pylint_output = result.stdout - report = self._parse_pylint_output(pylint_output) - return report - except subprocess.CalledProcessError as e: - print("Pylint analysis failed:", e) - return {} - except FileNotFoundError: - print("Pylint is not installed or not found in PATH.") - return {} - except json.JSONDecodeError: - print("Failed to parse pylint output. Check if pylint output is in JSON format.") - return {} - - def _parse_pylint_output(self, output: str): - """ - Parses the Pylint JSON output to identify specific code smells. 
- """ - try: - pylint_results = json.loads(output) - except json.JSONDecodeError: - print("Error: Failed to parse pylint output") - return [] - - code_smell_report = [] - - for entry in pylint_results: - message_id = entry.get("message-id") - if message_id in self.code_smells: - code_smell_report.append({ - "type": self.code_smells[message_id], - "message": entry.get("message"), - "line": entry.get("line"), - "column": entry.get("column"), - "path": entry.get("path") - }) - - return code_smell_report - -# Example usage -if __name__ == "__main__": - analyzer = PylintAnalyzer("your_file.py") - report = analyzer.analyze() - for issue in report: - print(f"{issue['type']} at {issue['path']}:{issue['line']}:{issue['column']} - {issue['message']}") diff --git a/src/README.md b/src/ecooptimizer/README.md similarity index 100% rename from src/README.md rename to src/ecooptimizer/README.md diff --git a/src/ecooptimizer/__init__.py b/src/ecooptimizer/__init__.py new file mode 100644 index 00000000..493243ca --- /dev/null +++ b/src/ecooptimizer/__init__.py @@ -0,0 +1,9 @@ +# Path of current directory +from pathlib import Path + +DIRNAME = Path(__file__).parent + +# Entire project directory path +SAMPLE_PROJ_DIR = (DIRNAME / Path("../../tests/input/project_car_stuff")).resolve() +SOURCE = SAMPLE_PROJ_DIR / "main.py" +TEST_FILE = SAMPLE_PROJ_DIR / "test_main.py" diff --git a/src/ecooptimizer/__main__.py b/src/ecooptimizer/__main__.py new file mode 100644 index 00000000..bbe683c2 --- /dev/null +++ b/src/ecooptimizer/__main__.py @@ -0,0 +1,132 @@ +import ast +import logging +from pathlib import Path +import shutil +from tempfile import TemporaryDirectory, mkdtemp # noqa: F401 + +import libcst as cst + +from .utils.output_manager import LoggingManager +from .utils.output_manager import save_file, save_json_files, copy_file_to_output + + +from .api.routes.refactor_smell import ChangedFile, RefactoredData + +from .measurements.codecarbon_energy_meter import CodeCarbonEnergyMeter + 
+from .analyzers.analyzer_controller import AnalyzerController + +from .refactorers.refactorer_controller import RefactorerController + +from . import ( + SAMPLE_PROJ_DIR, + SOURCE, +) + +from .config import CONFIG + +loggingManager = LoggingManager() + +CONFIG["loggingManager"] = loggingManager + +detect_logger = loggingManager.loggers["detect"] +refactor_logger = loggingManager.loggers["refactor"] + +CONFIG["detectLogger"] = detect_logger +CONFIG["refactorLogger"] = refactor_logger + + +# FILE CONFIGURATION IN __init__.py !!! + + +def main(): + # Save ast + save_file("source_ast.txt", ast.dump(ast.parse(SOURCE.read_text()), indent=4), "w") + save_file("source_cst.txt", str(cst.parse_module(SOURCE.read_text())), "w") + + # Measure initial energy + energy_meter = CodeCarbonEnergyMeter() + energy_meter.measure_energy(Path(SOURCE)) + initial_emissions = energy_meter.emissions + + if not initial_emissions: + logging.error("Could not retrieve initial emissions. Exiting.") + exit(1) + + analyzer_controller = AnalyzerController() + # update_smell_registry(["no-self-use"]) + smells_data = analyzer_controller.run_analysis(SOURCE) + save_json_files("code_smells.json", [smell.model_dump() for smell in smells_data]) + + copy_file_to_output(SOURCE, "refactored-test-case.py") + refactorer_controller = RefactorerController() + output_paths = [] + + for smell in smells_data: + # Use the line below and comment out "with TemporaryDirectory()" if you want to see the refactored code + # It basically copies the source directory into a temp dir that you can find in your systems TEMP folder + # It varies per OS. The location of the folder can be found in the 'refactored-data.json' file in outputs. + # If you use the other line know that you will have to manually delete the temp dir after running the + # code. 
It will NOT auto delete which, hence allowing you to see the refactoring results + + # tempDir = mkdtemp(prefix="ecooptimizer-") # < UNCOMMENT THIS LINE and shift code under to the left + + with TemporaryDirectory() as tempDir: # COMMENT OUT THIS ONE + source_copy = Path(tempDir) / SAMPLE_PROJ_DIR.name + target_file_copy = Path(str(SOURCE).replace(str(SAMPLE_PROJ_DIR), str(source_copy), 1)) + + # source_copy = project_copy / SOURCE.name + + shutil.copytree(SAMPLE_PROJ_DIR, source_copy) + + try: + modified_files: list[Path] = refactorer_controller.run_refactorer( + target_file_copy, source_copy, smell, overwrite=False + ) + except NotImplementedError as e: + print(e) + continue + + energy_meter.measure_energy(target_file_copy) + final_emissions = energy_meter.emissions + + if not final_emissions: + refactor_logger.error("Could not retrieve final emissions. Discarding refactoring.") + print("Refactoring Failed.\n") + + elif final_emissions >= initial_emissions: + refactor_logger.info("No measured energy savings. 
Discarding refactoring.\n") + print("Refactoring Failed.\n") + + else: + refactor_logger.info("Energy saved!") + refactor_logger.info( + f"Initial emissions: {initial_emissions} | Final emissions: {final_emissions}" + ) + + print("Refactoring Succesful!\n") + + refactor_data = RefactoredData( + tempDir=tempDir, + targetFile=ChangedFile(original=str(SOURCE), refactored=str(target_file_copy)), + energySaved=(final_emissions - initial_emissions), + affectedFiles=[ + ChangedFile( + original=str(file).replace(str(source_copy), str(SAMPLE_PROJ_DIR)), + refactored=str(file), + ) + for file in modified_files + ], + ) + + output_paths = refactor_data.affectedFiles + + # In reality the original code will now be overwritten but thats too much work + + save_json_files("refactoring-data.json", refactor_data.model_dump()) # type: ignore + + print(output_paths) + + +if __name__ == "__main__": + main() diff --git a/src/measurement/__init__.py b/src/ecooptimizer/analyzers/__init__.py similarity index 100% rename from src/measurement/__init__.py rename to src/ecooptimizer/analyzers/__init__.py diff --git a/src/ecooptimizer/analyzers/analyzer_controller.py b/src/ecooptimizer/analyzers/analyzer_controller.py new file mode 100644 index 00000000..65835b0c --- /dev/null +++ b/src/ecooptimizer/analyzers/analyzer_controller.py @@ -0,0 +1,137 @@ +# pyright: reportOptionalMemberAccess=false +from pathlib import Path +from typing import Callable, Any + +from ..data_types.smell_record import SmellRecord + +from ..config import CONFIG + +from ..data_types.smell import Smell + +from .pylint_analyzer import PylintAnalyzer +from .ast_analyzer import ASTAnalyzer +from .astroid_analyzer import AstroidAnalyzer + +from ..utils.smells_registry import retrieve_smell_registry + + +class AnalyzerController: + def __init__(self): + """Initializes analyzers for different analysis methods.""" + self.pylint_analyzer = PylintAnalyzer() + self.ast_analyzer = ASTAnalyzer() + self.astroid_analyzer = 
AstroidAnalyzer() + + def run_analysis(self, file_path: Path, selected_smells: str | list[str] = "ALL"): + """ + Runs multiple analysis tools on the given Python file and logs the results. + Returns a list of detected code smells. + """ + + smells_data: list[Smell] = [] + + if not selected_smells: + raise TypeError("At least 1 smell must be selected for detection") + + SMELL_REGISTRY = retrieve_smell_registry(selected_smells) + + try: + pylint_smells = self.filter_smells_by_method(SMELL_REGISTRY, "pylint") + ast_smells = self.filter_smells_by_method(SMELL_REGISTRY, "ast") + astroid_smells = self.filter_smells_by_method(SMELL_REGISTRY, "astroid") + + CONFIG["detectLogger"].info("🟒 Starting analysis process") + CONFIG["detectLogger"].info(f"πŸ“‚ Analyzing file: {file_path}") + + if pylint_smells: + CONFIG["detectLogger"].info(f"πŸ” Running Pylint analysis on {file_path}") + pylint_options = self.generate_pylint_options(pylint_smells) + pylint_results = self.pylint_analyzer.analyze(file_path, pylint_options) + smells_data.extend(pylint_results) + CONFIG["detectLogger"].info( + f"βœ… Pylint analysis completed. {len(pylint_results)} smells detected." + ) + + if ast_smells: + CONFIG["detectLogger"].info(f"πŸ” Running AST analysis on {file_path}") + ast_options = self.generate_custom_options(ast_smells) + ast_results = self.ast_analyzer.analyze(file_path, ast_options) + smells_data.extend(ast_results) + CONFIG["detectLogger"].info( + f"βœ… AST analysis completed. {len(ast_results)} smells detected." + ) + + if astroid_smells: + CONFIG["detectLogger"].info(f"πŸ” Running Astroid analysis on {file_path}") + astroid_options = self.generate_custom_options(astroid_smells) + astroid_results = self.astroid_analyzer.analyze(file_path, astroid_options) + smells_data.extend(astroid_results) + CONFIG["detectLogger"].info( + f"βœ… Astroid analysis completed. {len(astroid_results)} smells detected." 
+ ) + + if smells_data: + CONFIG["detectLogger"].info("⚠️ Detected Code Smells:") + for smell in smells_data: + if smell.occurences: + first_occurrence = smell.occurences[0] + total_occurrences = len(smell.occurences) + line_info = ( + f"(Starting at Line {first_occurrence.line}, {total_occurrences} occurrences)" + if total_occurrences > 1 + else f"(Line {first_occurrence.line})" + ) + else: + line_info = "" + + CONFIG["detectLogger"].info(f" β€’ {smell.symbol} {line_info}: {smell.message}") + else: + CONFIG["detectLogger"].info("πŸŽ‰ No code smells detected.") + + except Exception as e: + CONFIG["detectLogger"].error(f"❌ Error during analysis: {e!s}") + + return smells_data + + @staticmethod + def filter_smells_by_method( + smell_registry: dict[str, SmellRecord], method: str + ) -> dict[str, SmellRecord]: + filtered = { + name: smell + for name, smell in smell_registry.items() + if smell["enabled"] and (method == smell["analyzer_method"]) + } + return filtered + + @staticmethod + def generate_pylint_options(filtered_smells: dict[str, SmellRecord]) -> list[str]: + pylint_smell_symbols = [] + extra_pylint_options = [ + "--disable=all", + ] + + for symbol, smell in zip(filtered_smells.keys(), filtered_smells.values()): + pylint_smell_symbols.append(symbol) + + if len(smell["analyzer_options"]) > 0: + for param_data in smell["analyzer_options"].values(): + flag = param_data["flag"] + value = param_data["value"] + if value: + extra_pylint_options.append(f"{flag}={value}") + + extra_pylint_options.append(f"--enable={','.join(pylint_smell_symbols)}") + return extra_pylint_options + + @staticmethod + def generate_custom_options( + filtered_smells: dict[str, SmellRecord], + ) -> list[tuple[Callable, dict[str, Any]]]: # type: ignore + ast_options = [] + for smell in filtered_smells.values(): + method = smell["checker"] + options = smell["analyzer_options"] + ast_options.append((method, options)) + + return ast_options diff --git a/src/ecooptimizer/analyzers/ast_analyzer.py 
b/src/ecooptimizer/analyzers/ast_analyzer.py new file mode 100644 index 00000000..e9c0b051 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzer.py @@ -0,0 +1,27 @@ +from typing import Callable, Any +from pathlib import Path +from ast import AST, parse + + +from .base_analyzer import Analyzer +from ..data_types.smell import Smell + + +class ASTAnalyzer(Analyzer): + def analyze( + self, + file_path: Path, + extra_options: list[tuple[Callable[[Path, AST], list[Smell]], dict[str, Any]]], + ): + smells_data: list[Smell] = [] + + source_code = file_path.read_text() + + tree = parse(source_code) + + for detector, params in extra_options: + if callable(detector): + result = detector(file_path, tree, **params) + smells_data.extend(result) + + return smells_data diff --git a/src/refactorer/__init__.py b/src/ecooptimizer/analyzers/ast_analyzers/__init__.py similarity index 100% rename from src/refactorer/__init__.py rename to src/ecooptimizer/analyzers/ast_analyzers/__init__.py diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_long_element_chain.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_element_chain.py new file mode 100644 index 00000000..3fa39d86 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_element_chain.py @@ -0,0 +1,73 @@ +import ast +from pathlib import Path + +from ...utils.smell_enums import CustomSmell + +from ...data_types.smell import LECSmell +from ...data_types.custom_fields import AdditionalInfo, Occurence + + +def detect_long_element_chain(file_path: Path, tree: ast.AST, threshold: int = 5) -> list[LECSmell]: + """ + Detects long element chains in the given Python code and returns a list of Smell objects. + + Args: + file_path (Path): The file path to analyze. + tree (ast.AST): The Abstract Syntax Tree (AST) of the source code. + threshold (int): The minimum length of a dictionary chain. Default is 3. 
+ + Returns: + list[Smell]: A list of Smell objects, each containing details about a detected long chain. + """ + # Initialize an empty list to store detected Smell objects + results: list[LECSmell] = [] + used_lines = set() + + # Function to calculate the length of a dictionary chain and detect long chains + def check_chain(node: ast.Subscript, chain_length: int = 0): + # Ensure each line is only reported once + if node.lineno in used_lines: + return + + current = node + # Traverse through the chain to count its length + while isinstance(current, ast.Subscript): + chain_length += 1 + current = current.value + + print(chain_length) + if chain_length >= threshold: + # Create a descriptive message for the detected long chain + message = f"Dictionary chain too long ({chain_length}/{threshold})" + print(node.lineno) + # Instantiate a Smell object with details about the detected issue + smell = LECSmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-element-chain", + message=message, + messageId=CustomSmell.LONG_ELEMENT_CHAIN.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + + used_lines.add(node.lineno) + results.append(smell) + + # Traverse the AST to identify nodes representing dictionary chains + for node in ast.walk(tree): + if isinstance(node, ast.Subscript): + check_chain(node) + + # Return the list of detected Smell objects + return results diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_long_lambda_expression.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_lambda_expression.py new file mode 100644 index 00000000..2ff0fccb --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_lambda_expression.py @@ -0,0 +1,152 @@ +import ast +from pathlib import Path + +from ...utils.smell_enums import CustomSmell + +from 
...data_types.smell import LLESmell +from ...data_types.custom_fields import AdditionalInfo, Occurence + + +def count_expressions(node: ast.expr) -> int: + """ + Recursively counts the number of sub-expressions inside a lambda body. + Ensures `sum()` only operates on integers. + """ + if isinstance(node, (ast.BinOp, ast.BoolOp, ast.Compare, ast.Call, ast.IfExp)): + return 1 + sum( + count_expressions(child) + for child in ast.iter_child_nodes(node) + if isinstance(child, ast.expr) + ) + + # Ensure all recursive calls return an integer + return sum( + ( + count_expressions(child) + for child in ast.iter_child_nodes(node) + if isinstance(child, ast.expr) + ), + start=0, + ) + + +# Helper function to get the string representation of the lambda expression +def get_lambda_code(lambda_node: ast.Lambda) -> str: + """ + Constructs the string representation of a lambda expression. + + Args: + lambda_node (ast.Lambda): The lambda node to reconstruct. + + Returns: + str: The string representation of the lambda expression. + """ + # Reconstruct the lambda arguments and body as a string + args = ", ".join(arg.arg for arg in lambda_node.args.args) + + # Convert the body to a string by using ast's built-in functionality + body = ast.unparse(lambda_node.body) + + # Combine to form the lambda expression + return f"lambda {args}: {body}" + + +def detect_long_lambda_expression( + file_path: Path, + tree: ast.AST, + threshold_length: int = 100, + threshold_count: int = 5, +) -> list[LLESmell]: + """ + Detects lambda functions that are too long, either by the number of expressions or the total length in characters. + + Args: + file_path (Path): The file path to analyze. + tree (ast.AST): The Abstract Syntax Tree (AST) of the source code. + threshold_length (int): The maximum number of characters allowed in the lambda expression. + threshold_count (int): The maximum number of expressions allowed inside the lambda function. 
+ + Returns: + list[Smell]: A list of Smell objects, each containing details about detected long lambda functions. + """ + # Initialize an empty list to store detected Smell objects + results: list[LLESmell] = [] + used_lines = set() + + # Function to check the length of lambda expressions + def check_lambda(node: ast.Lambda): + """ + Analyzes a lambda node to check if it exceeds the specified thresholds + for the number of expressions or total character length. + + Args: + node (ast.Lambda): The lambda node to analyze. + """ + # Count the number of expressions in the lambda body + lambda_length = count_expressions(node.body) + + # Check if the lambda expression exceeds the threshold based on the number of expressions + if lambda_length >= threshold_count: + message = f"Lambda function too long ({lambda_length}/{threshold_count} expressions)" + # Initialize the Smell instance + smell = LLESmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-lambda-expression", + message=message, + messageId=CustomSmell.LONG_LAMBDA_EXPR.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + + if node.lineno in used_lines: + return + used_lines.add(node.lineno) + results.append(smell) + + # Convert the lambda function to a string and check its total length in characters + lambda_code = get_lambda_code(node) + if len(lambda_code) > threshold_length: + message = f"Lambda function too long ({len(lambda_code)} characters, max {threshold_length})" + smell = LLESmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-lambda-expr", + message=message, + messageId=CustomSmell.LONG_LAMBDA_EXPR.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + 
endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + + if node.lineno in used_lines: + return + used_lines.add(node.lineno) + results.append(smell) + + # Walk through the AST to find lambda expressions + for node in ast.walk(tree): + if isinstance(node, ast.Lambda): + check_lambda(node) + + # Return the list of detected Smell objects + return results diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_long_message_chain.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_message_chain.py new file mode 100644 index 00000000..b3d59c73 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_message_chain.py @@ -0,0 +1,85 @@ +import ast +from pathlib import Path + +from ...utils.smell_enums import CustomSmell + +from ...data_types.smell import LMCSmell +from ...data_types.custom_fields import AdditionalInfo, Occurence + + +def compute_chain_length(node: ast.expr) -> int: + """ + Recursively determines how many consecutive calls exist in a chain + ending at 'node'. Each .something() is +1. + """ + if isinstance(node, ast.Call): + # We have a call, so that's +1 + if isinstance(node.func, ast.Attribute): + # The chain might continue if node.func.value is also a call + return 1 + compute_chain_length(node.func.value) + else: + return 1 + elif isinstance(node, ast.Attribute): + # If it's just an attribute (like `details` or `obj.x`), + # we keep looking up the chain but *don’t increment*, + # because we only count calls. + return compute_chain_length(node.value) + else: + # If it's a Name or something else, we stop + return 0 + + +def detect_long_message_chain( + file_path: Path, tree: ast.AST, threshold: int = 5 +) -> list[LMCSmell]: + """ + Detects long message chains in the given Python code. + + Args: + file_path (Path): The file path to analyze. + tree (ast.AST): The Abstract Syntax Tree (AST) of the source code. + threshold (int): The minimum number of chained method calls to flag as a long chain. 
Default is 5. + + Returns: + list[Smell]: A list of Smell objects, each containing details about the detected long chains. + """ + # Initialize an empty list to store detected Smell objects + results: list[LMCSmell] = [] + used_lines = set() + + # Walk through the AST to find method calls and attribute chains + for node in ast.walk(tree): + # Check only method calls (Call node whose func is an Attribute) + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + length = compute_chain_length(node) + if length >= threshold: + line = node.lineno + # Make sure we haven’t already reported on this line + if line not in used_lines: + used_lines.add(line) + + message = f"Method chain too long ({length}/{threshold})" + # Create the smell object + smell = LMCSmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-message-chain", + message=message, + messageId=CustomSmell.LONG_MESSAGE_CHAIN.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + results.append(smell) + + # Return the list of detected Smell objects + return results diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_repeated_calls.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_repeated_calls.py new file mode 100644 index 00000000..6764ad7b --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_repeated_calls.py @@ -0,0 +1,145 @@ +import ast +from collections import defaultdict +from pathlib import Path +import astor + +from ...data_types.custom_fields import CRCInfo, Occurence +from ...data_types.smell import CRCSmell +from ...utils.smell_enums import CustomSmell + + +IGNORED_PRIMITIVE_BUILTINS = {"abs", "round"} # Built-ins safe to ignore when used with primitives +IGNORED_CONSTRUCTORS = {"set", "list", "dict", "tuple"} # Constructors to ignore 
+EXPENSIVE_BUILTINS = { + "max", + "sum", + "sorted", + "min", +} # Built-ins to track when argument is non-primitive + + +def is_primitive_expression(node: ast.AST): + """Returns True if the AST node is a primitive (int, float, str, bool), including negative numbers.""" + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, str, bool)): + return True + if ( + isinstance(node, ast.UnaryOp) + and isinstance(node.op, (ast.UAdd, ast.USub)) + and isinstance(node.operand, ast.Constant) + ): + return isinstance(node.operand.value, (int, float)) + return False + + +def detect_repeated_calls(file_path: Path, tree: ast.AST, threshold: int = 2): + results: list[CRCSmell] = [] + + source_code = file_path.read_text() + + def match_quote_style(source: str, function_call: str): + """Detect whether the function call uses single or double quotes in the source.""" + if function_call.replace('"', "'") in source: + return "'" + return '"' + + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.For, ast.While)): + call_counts: dict[str, list[ast.Call]] = defaultdict(list) + assigned_calls = set() + modified_objects = {} + call_lines = {} + + # Track assignments (only calls assigned to a variable should be considered) + for subnode in ast.walk(node): + if isinstance(subnode, ast.Assign) and isinstance(subnode.value, ast.Call): + call_repr = astor.to_source(subnode.value).strip() + assigned_calls.add(call_repr) + + # Track object attribute modifications (e.g., obj.value = 10) + for subnode in ast.walk(node): + if isinstance(subnode, ast.Assign) and isinstance( + subnode.targets[0], ast.Attribute + ): + obj_name = astor.to_source(subnode.targets[0].value).strip() + modified_objects[obj_name] = subnode.lineno + + # Track function calls + for subnode in ast.walk(node): + if isinstance(subnode, ast.Call): + raw_call_string = astor.to_source(subnode).strip() + call_line = subnode.lineno + + preferred_quote = match_quote_style(source_code, 
def detect_repeated_calls(file_path: Path, tree: ast.AST, threshold: int = 2):
    """Detect function calls repeated within the same scope (function or loop).

    A call is reported when its source text appears at least ``threshold`` times
    inside one FunctionDef/For/While scope and its result is assigned to a
    variable (or it is an expensive built-in over a non-primitive argument),
    suggesting the result could be cached.

    Args:
        file_path: File being analyzed (used for reading source and reporting).
        tree: Parsed AST of the file.
        threshold: Minimum number of identical calls before a smell is raised.

    Returns:
        list[CRCSmell]: One smell per distinct repeated call string.
    """
    results: list[CRCSmell] = []

    source_code = file_path.read_text()

    def match_quote_style(source: str, function_call: str):
        """Detect whether the function call uses single or double quotes in the source."""
        if function_call.replace('"', "'") in source:
            return "'"
        return '"'

    for node in ast.walk(tree):
        # Each candidate scope gets its own counters, so repeats are per-scope.
        if isinstance(node, (ast.FunctionDef, ast.For, ast.While)):
            call_counts: dict[str, list[ast.Call]] = defaultdict(list)
            assigned_calls = set()
            modified_objects = {}
            # NOTE(review): call_lines is written below but never read — confirm
            # whether it is dead state or intended for future use.
            call_lines = {}

            # Track assignments (only calls assigned to a variable should be considered)
            for subnode in ast.walk(node):
                if isinstance(subnode, ast.Assign) and isinstance(subnode.value, ast.Call):
                    call_repr = astor.to_source(subnode.value).strip()
                    assigned_calls.add(call_repr)

            # Track object attribute modifications (e.g., obj.value = 10)
            for subnode in ast.walk(node):
                if isinstance(subnode, ast.Assign) and isinstance(
                    subnode.targets[0], ast.Attribute
                ):
                    obj_name = astor.to_source(subnode.targets[0].value).strip()
                    modified_objects[obj_name] = subnode.lineno

            # Track function calls
            for subnode in ast.walk(node):
                if isinstance(subnode, ast.Call):
                    raw_call_string = astor.to_source(subnode).strip()
                    call_line = subnode.lineno

                    # Normalize quote style so identical calls compare equal
                    # regardless of astor's quoting.
                    preferred_quote = match_quote_style(source_code, raw_call_string)
                    callString = raw_call_string.replace("'", preferred_quote).replace(
                        '"', preferred_quote
                    )

                    # Ignore built-in functions when their argument is a primitive
                    if isinstance(subnode.func, ast.Name):
                        func_name = subnode.func.id

                        if func_name in IGNORED_CONSTRUCTORS:
                            continue

                        if func_name in IGNORED_PRIMITIVE_BUILTINS:
                            if len(subnode.args) == 1 and is_primitive_expression(subnode.args[0]):
                                continue

                        # Expensive built-ins count even without assignment,
                        # but only over non-primitive arguments.
                        if func_name in EXPENSIVE_BUILTINS:
                            if len(subnode.args) == 1 and not is_primitive_expression(
                                subnode.args[0]
                            ):
                                call_counts[callString].append(subnode)
                            continue

                        # Check if it's a class by looking for capitalized names (heuristic)
                        if func_name[0].isupper():
                            continue

                    obj_name = (
                        astor.to_source(subnode.func.value).strip()
                        if isinstance(subnode.func, ast.Attribute)
                        else None
                    )

                    # Skip method calls whose receiver was mutated earlier in the
                    # scope — the repeated result may legitimately differ.
                    if obj_name:
                        if obj_name in modified_objects and modified_objects[obj_name] < call_line:
                            continue

                    # NOTE(review): this branch keys on raw_call_string while the
                    # expensive-builtin branch keys on the normalized callString —
                    # the two families can never merge; confirm intended.
                    if raw_call_string in assigned_calls:
                        call_counts[raw_call_string].append(subnode)
                        call_lines[raw_call_string] = call_line

            # Identify repeated calls
            for callString, occurrences in call_counts.items():
                if len(occurrences) >= threshold:
                    preferred_quote = match_quote_style(source_code, callString)
                    normalized_callString = callString.replace("'", preferred_quote).replace(
                        '"', preferred_quote
                    )

                    smell = CRCSmell(
                        path=str(file_path),
                        type="performance",
                        obj=None,
                        module=file_path.stem,
                        symbol="cached-repeated-calls",
                        message=f"Repeated function call detected ({len(occurrences)}/{threshold}). Consider caching the result: {normalized_callString}",
                        messageId=CustomSmell.CACHE_REPEATED_CALLS.value,
                        # Strictly more repeats than the threshold upgrades confidence.
                        confidence="HIGH" if len(occurrences) > threshold else "MEDIUM",
                        occurences=[
                            Occurence(
                                line=occ.lineno,
                                endLine=occ.end_lineno,
                                column=occ.col_offset,
                                endColumn=occ.end_col_offset,
                            )
                            for occ in occurrences
                        ],
                        additionalInfo=CRCInfo(
                            repetitions=len(occurrences), callString=normalized_callString
                        ),
                    )
                    results.append(smell)

    return results
class AstroidAnalyzer(Analyzer):
    """Analyzer that feeds an astroid-parsed module to detector callables."""

    def analyze(
        self,
        file_path: Path,
        extra_options: list[
            tuple[
                Callable[[Path, nodes.Module], list[Smell]],
                dict[str, Any],
            ]
        ],
    ):
        """Parse *file_path* once and run every configured detector over it.

        Each ``extra_options`` entry pairs a detector callable with its keyword
        arguments; entries whose first element is not callable are skipped.
        Returns the concatenated list of smells from all detectors.
        """
        module_tree = parse(file_path.read_text())

        collected: list[Smell] = []
        for checker, kwargs in extra_options:
            if not callable(checker):
                continue
            collected.extend(checker(file_path, module_tree, **kwargs))

        return collected
def detect_string_concat_in_loop(file_path: Path, tree: nodes.Module):
    """
    Detects string concatenation inside loops within a Python AST tree.

    The file is re-parsed after rewriting every ``x += ...`` into
    ``x = x + ...`` so the traversal only has to reason about Assign nodes.
    Consecutive concatenations onto the same target inside the same loop are
    folded into one smell with multiple occurences.

    Parameters:
        file_path (Path): The file path to analyze.
        tree (nodes.Module): The parsed AST tree of the Python code (re-parsed
            internally after the AugAssign rewrite).

    Returns:
        list[SCLSmell]: Detected string-concatenation-in-loop smells.
    """
    smells: list[SCLSmell] = []
    in_loop_counter = 0
    current_loops: list[nodes.NodeNG] = []
    # current_smells = { var_name : ( index of smell, index of loop ) }
    current_smells: dict[str, tuple[int, int]] = {}

    def create_smell(node: nodes.Assign):
        """Append a new SCLSmell for the concat assignment *node*."""
        nonlocal current_loops, current_smells

        if node.lineno and node.col_offset:
            smells.append(
                SCLSmell(
                    path=str(file_path),
                    module=file_path.name,
                    obj=None,
                    type="performance",
                    symbol="string-concat-loop",
                    message="String concatenation inside loop detected",
                    messageId=CustomSmell.STR_CONCAT_IN_LOOP.value,
                    confidence="UNDEFINED",
                    occurences=[create_smell_occ(node)],
                    additionalInfo=SCLInfo(
                        innerLoopLine=current_loops[
                            current_smells[node.targets[0].as_string()][1]
                        ].lineno,  # type: ignore
                        concatTarget=node.targets[0].as_string(),
                    ),
                )
            )

    def create_smell_occ(node: nodes.Assign | nodes.AugAssign) -> Occurence:
        """Build an Occurence from the node's position info."""
        return Occurence(
            line=node.lineno,  # type: ignore
            endLine=node.end_lineno,
            column=node.col_offset,  # type: ignore
            endColumn=node.end_col_offset,
        )

    def visit(node: nodes.NodeNG):
        """Depth-first walk tracking loop nesting and concat assignments."""
        nonlocal smells, in_loop_counter, current_loops, current_smells

        if isinstance(node, (nodes.For, nodes.While)):
            in_loop_counter += 1
            current_loops.append(node)
            for stmt in node.body:
                visit(stmt)

            in_loop_counter -= 1

            # Forget smells scoped to the loop we just left so the same
            # variable in a sibling loop starts a fresh smell.
            current_smells = {
                key: val for key, val in current_smells.items() if val[1] != in_loop_counter
            }
            current_loops.pop()

        elif in_loop_counter > 0 and isinstance(node, nodes.Assign):
            # BUGFIX: original guard was `len(node.targets) == 1 > 1`, a chained
            # comparison equivalent to `len(...) == 1 and 1 > 1`, which is always
            # False — multi-target assignments (a = b = x + ...) were never skipped.
            if len(node.targets) > 1:
                return

            target = node.targets[0]
            value = node.value

            if target and isinstance(value, nodes.BinOp) and value.op == "+":
                if (
                    target.as_string() not in current_smells
                    and is_string_type(node)
                    and is_concatenating_with_self(value, target)
                    and is_not_referenced(node)
                ):
                    # First sighting in this loop: remember smell/loop indices.
                    current_smells[target.as_string()] = (
                        len(smells),
                        in_loop_counter - 1,
                    )
                    create_smell(node)
                elif target.as_string() in current_smells and is_concatenating_with_self(
                    value, target
                ):
                    # Repeat concat onto a known target: extend the existing smell.
                    smell_id = current_smells[target.as_string()][0]
                    smells[smell_id].occurences.append(create_smell_occ(node))
        else:
            for child in node.get_children():
                visit(child)

    def is_not_referenced(node: nodes.Assign):
        """True when the concat target is never *read* elsewhere in the innermost loop."""
        nonlocal current_loops

        loop_source_str = current_loops[-1].as_string()
        # Remove the assignment itself before scanning for other references.
        loop_source_str = loop_source_str.replace(node.as_string(), "", 1)
        lines = loop_source_str.splitlines()
        for line in lines:
            if (
                line.find(node.targets[0].as_string()) != -1
                and re.search(rf"\b{re.escape(node.targets[0].as_string())}\b\s*=", line) is None
            ):
                return False
        return True

    def is_concatenating_with_self(binop_node: nodes.BinOp, target: nodes.NodeNG):
        """Check if the BinOp node includes the target variable being added."""

        def is_same_variable(var1: nodes.NodeNG, var2: nodes.NodeNG):
            if isinstance(var1, nodes.Name) and isinstance(var2, nodes.AssignName):
                return var1.name == var2.name
            if isinstance(var1, nodes.Attribute) and isinstance(var2, nodes.AssignAttr):
                return var1.as_string() == var2.as_string()
            if isinstance(var1, nodes.Subscript) and isinstance(var2, nodes.Subscript):
                if isinstance(var1.slice, nodes.Const) and isinstance(var2.slice, nodes.Const):
                    return var1.as_string() == var2.as_string()
            # Nested `a + b + c`: recurse into the left-associated sub-expression.
            if isinstance(var1, nodes.BinOp) and var1.op == "+":
                return is_same_variable(var1.left, target) or is_same_variable(var1.right, target)
            return False

        left, right = binop_node.left, binop_node.right
        return is_same_variable(left, target) or is_same_variable(right, target)

    def is_string_type(node: nodes.Assign) -> bool:
        """Best-effort check that the assignment target holds a str."""
        target = node.targets[0]

        # Check type hints first
        if has_type_hints_str(node, target):
            return True

        # Infer types
        for inferred in target.infer():
            if inferred.repr_name() == "str":
                return True
            if isinstance(inferred, util.UninferableBase):
                # Inference gave up: fall back to syntactic evidence of
                # string formatting/interpolation in the RHS.
                if has_str_format(node.value) or has_str_interpolation(node.value):
                    return True
                for var in node.value.nodes_of_class(
                    (nodes.Name, nodes.Attribute, nodes.Subscript)
                ):
                    if var.as_string() == target.as_string():
                        for inferred_target in var.infer():
                            if inferred_target.repr_name() == "str":
                                return True

                        if has_type_hints_str(node, var):
                            return True

        return False

    def has_type_hints_str(context: nodes.NodeNG, target: nodes.NodeNG) -> bool:
        """Checks if a variable has an explicit type hint for `str`"""
        parent = context.scope()

        # Function argument type hints
        if isinstance(parent, nodes.FunctionDef) and parent.args.args:
            for arg, ann in zip(parent.args.args, parent.args.annotations):
                if arg.name == target.as_string() and ann and ann.as_string() == "str":
                    return True

        # Class attributes (annotations in class scope or __init__)
        if "self." in target.as_string():
            class_def = parent.frame()
            if not isinstance(class_def, nodes.ClassDef):
                class_def = next(
                    (
                        ancestor
                        for ancestor in context.node_ancestors()
                        if isinstance(ancestor, nodes.ClassDef)
                    ),
                    None,
                )

            if class_def:
                attr_name = target.as_string().replace("self.", "")
                try:
                    for attr in class_def.instance_attr(attr_name):
                        if (
                            isinstance(attr, nodes.AnnAssign)
                            and attr.annotation.as_string() == "str"
                        ):
                            return True
                        if any(inf.repr_name() == "str" for inf in attr.infer()):
                            return True
                except AttributeInferenceError:
                    pass

        # Global/scope variable annotations before assignment
        for child in parent.nodes_of_class((nodes.AnnAssign, nodes.Assign)):
            if child == context:
                break
            if (
                isinstance(child, nodes.AnnAssign)
                and child.target.as_string() == target.as_string()
            ):
                return child.annotation.as_string() == "str"
            if isinstance(child, nodes.Assign) and is_string_type(child):
                return True

        return False

    def has_str_format(node: nodes.NodeNG):
        """True when the RHS of a `+` chain contains a `{...}` format placeholder."""
        if isinstance(node, nodes.BinOp) and node.op == "+":
            str_repr = node.as_string()
            match = re.search("{.*}", str_repr)
            if match:
                return True

        return False

    def has_str_interpolation(node: nodes.NodeNG):
        """True when the RHS of a `+` chain contains a `%x` interpolation marker."""
        if isinstance(node, nodes.BinOp) and node.op == "+":
            str_repr = node.as_string()
            match = re.search("%[a-z]", str_repr)
            if match:
                return True
        return False

    def transform_augassign_to_assign(code_file: str):
        """
        Changes all AugAssign occurences to Assign in a code file.

        :param code_file: The source code file as a string
        :return: The same string source code with all AugAssign stmts changed to Assign
        """
        str_code = code_file.splitlines()

        for i in range(len(str_code)):
            eq_col = str_code[i].find(" +=")

            if eq_col == -1:
                continue

            target_var = str_code[i][0:eq_col].strip()

            # Replace '+=' with '=' to form an Assign string
            str_code[i] = str_code[i].replace("+=", f"= {target_var} +", 1)

        return "\n".join(str_code)

    # Change all AugAssigns to Assigns
    tree = parse(transform_augassign_to_assign(file_path.read_text()))

    # Start traversal
    for child in tree.get_children():
        visit(child)

    return smells
Smell( + confidence=smell["confidence"], + message=smell["message"], + messageId=smell["messageId"], + module=smell["module"], + obj=smell["obj"], + path=smell["absolutePath"], + symbol=smell["symbol"], + type=smell["type"], + occurences=[ + Occurence( + line=smell["line"], + endLine=smell["endLine"], + column=smell["column"], + endColumn=smell["endColumn"], + ) + ], + additionalInfo=AdditionalInfo(), + ) + ) + + return smells + + def analyze(self, file_path: Path, extra_options: list[str]): + smells_data: list[Smell] = [] + pylint_options = [str(file_path), *extra_options] + + with StringIO() as buffer: + reporter = JSON2Reporter(buffer) + + try: + Run(pylint_options, reporter=reporter, exit=False) + buffer.seek(0) + smells_data.extend(self._build_smells(json.loads(buffer.getvalue())["messages"])) + except json.JSONDecodeError as e: + CONFIG["detectLogger"].error(f"❌ Failed to parse JSON output from pylint: {e}") # type: ignore + except Exception as e: + CONFIG["detectLogger"].error(f"❌ An error occurred during pylint analysis: {e}") # type: ignore + + return smells_data diff --git a/src/utils/__init__.py b/src/ecooptimizer/api/__init__.py similarity index 100% rename from src/utils/__init__.py rename to src/ecooptimizer/api/__init__.py diff --git a/src/ecooptimizer/api/__main__.py b/src/ecooptimizer/api/__main__.py new file mode 100644 index 00000000..aa1f1713 --- /dev/null +++ b/src/ecooptimizer/api/__main__.py @@ -0,0 +1,57 @@ +import logging +import sys +import uvicorn + +from .app import app + +from ..config import CONFIG + + +class HealthCheckFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return "/health" not in record.getMessage() + + +# Apply the filter to Uvicorn's access logger +logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter()) + + +def start(): + # ANSI codes + RESET = "\u001b[0m" + BLUE = "\u001b[36m" + PURPLE = "\u001b[35m" + + mode_message = f"{CONFIG['mode'].upper()} MODE" + msg_len = 
len(mode_message) + + print(f"\n\t\t\t***{'*'*msg_len}***") + print(f"\t\t\t* {BLUE}{mode_message}{RESET} *") + print(f"\t\t\t***{'*'*msg_len}***\n") + if CONFIG["mode"] == "production": + print(f"{PURPLE}hint: add --dev flag at the end to ignore energy checks\n") + + logging.info("πŸš€ Running EcoOptimizer Application...") + logging.info(f"{'=' * 100}\n") + uvicorn.run( + app, + host="127.0.0.1", + port=8000, + log_level="info", + access_log=True, + timeout_graceful_shutdown=2, + ) + + +def main(): + CONFIG["mode"] = "development" if "--dev" in sys.argv else "production" + start() + + +def dev(): + CONFIG["mode"] = "development" + start() + + +if __name__ == "__main__": + main() diff --git a/src/ecooptimizer/api/app.py b/src/ecooptimizer/api/app.py new file mode 100644 index 00000000..bace8451 --- /dev/null +++ b/src/ecooptimizer/api/app.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI +from .routes import RefactorRouter, DetectRouter, LogRouter + + +app = FastAPI(title="Ecooptimizer") + +# Include API routes +app.include_router(RefactorRouter) +app.include_router(DetectRouter) +app.include_router(LogRouter) + + +@app.get("/health") +async def ping(): + return {"status": "ok"} diff --git a/src/ecooptimizer/api/routes/__init__.py b/src/ecooptimizer/api/routes/__init__.py new file mode 100644 index 00000000..b0b59465 --- /dev/null +++ b/src/ecooptimizer/api/routes/__init__.py @@ -0,0 +1,5 @@ +from .refactor_smell import router as RefactorRouter +from .detect_smells import router as DetectRouter +from .show_logs import router as LogRouter + +__all__ = ["DetectRouter", "LogRouter", "RefactorRouter"] diff --git a/src/ecooptimizer/api/routes/detect_smells.py b/src/ecooptimizer/api/routes/detect_smells.py new file mode 100644 index 00000000..fb86357c --- /dev/null +++ b/src/ecooptimizer/api/routes/detect_smells.py @@ -0,0 +1,66 @@ +# pyright: reportOptionalMemberAccess=false +from pathlib import Path +from fastapi import APIRouter, HTTPException +from pydantic import 
class SmellRequest(BaseModel):
    # Path of the file to analyze and the smell ids the client enabled.
    file_path: str
    enabled_smells: list[str]


@router.post("/smells", response_model=list[Smell])
def detect_smells(request: SmellRequest):
    """
    Detects code smells in a given file, logs the process, and measures execution time.
    """
    log = CONFIG["detectLogger"]
    log.info("=" * 100)
    log.info(f"πŸ“‚ Received smell detection request for: {request.file_path}")

    started = time.time()

    try:
        target = Path(request.file_path)

        if not target.exists():
            log.error(f"❌ File does not exist: {target}")
            raise FileNotFoundError(f"File not found: {target}")

        enabled = ", ".join(request.enabled_smells) if request.enabled_smells else "None"
        log.debug(f"πŸ”Ž Enabled smells: {enabled}")

        # Run analysis
        log.info(f"🎯 Running analysis on: {target}")
        found = analyzer_controller.run_analysis(target, request.enabled_smells)

        elapsed = round(time.time() - started, 2)
        log.info(f"πŸ“Š Execution Time: {elapsed} seconds")
        log.info(f"🏁 Analysis completed for {target}. {len(found)} smells found.")
        log.info("=" * 100 + "\n")

        return found

    except FileNotFoundError as e:
        log.error(f"❌ File not found: {e}")
        log.info("=" * 100 + "\n")
        raise HTTPException(status_code=404, detail=str(e)) from e

    except Exception as e:
        log.error(f"❌ Error during smell detection: {e!s}")
        log.info("=" * 100 + "\n")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@router.post("/refactor", response_model=RefactorResModel)
def refactor(request: RefactorRqModel):
    """Handles the refactoring process for a given smell."""
    log = CONFIG["refactorLogger"]
    log.info("=" * 100)
    log.info("πŸ”„ Received refactor request.")

    try:
        log.info(f"πŸ” Analyzing smell: {request.smell.symbol} in {request.source_dir}")
        refactor_data, updated_smells = perform_refactoring(
            Path(request.source_dir), request.smell
        )

        log.info(f"βœ… Refactoring process completed. Updated smells: {len(updated_smells)}")

        # A successful refactoring with measurable results carries payload data;
        # otherwise only the re-analyzed smell list is returned.
        if refactor_data:
            cleaned = clean_refactored_data(refactor_data)
            log.info("=" * 100 + "\n")
            return RefactorResModel(refactoredData=cleaned, updatedSmells=updated_smells)

        log.info("=" * 100 + "\n")
        return RefactorResModel(updatedSmells=updated_smells)

    except OSError as e:
        log.error(f"❌ OS error: {e!s}")
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        log.error(f"❌ Refactoring error: {e!s}")
        log.info("=" * 100 + "\n")
        raise HTTPException(status_code=400, detail=str(e)) from e
def perform_refactoring(source_dir: Path, smell: Smell):
    """Executes the refactoring process for a given smell.

    Copies ``source_dir`` into a fresh temp directory, refactors the copy,
    and (in production mode) keeps the result only when emissions improved.

    Returns:
        tuple: (refactor_data dict or the refactorer's empty result, updated smell list).

    Raises:
        OSError: if ``source_dir`` does not exist.
        RuntimeError: if an energy measurement could not be taken.
        RefactoringError: if the refactorer itself failed.
        EnergySavingsError: if no energy was saved (production mode only).
    """
    target_file = Path(smell.path)

    CONFIG["refactorLogger"].info(
        f"πŸš€ Starting refactoring for {smell.symbol} at line {smell.occurences[0].line} in {target_file}"
    )

    if not source_dir.is_dir():
        CONFIG["refactorLogger"].error(f"❌ Directory does not exist: {source_dir}")
        raise OSError(f"Directory {source_dir} does not exist.")

    # Baseline energy reading of the original file.
    initial_emissions = measure_energy(target_file)

    if not initial_emissions:
        CONFIG["refactorLogger"].error("❌ Could not retrieve initial emissions.")
        raise RuntimeError("Could not retrieve initial emissions.")

    CONFIG["refactorLogger"].info(f"πŸ“Š Initial emissions: {initial_emissions} kg CO2")

    # Refactor a copy so the original tree is untouched until savings are proven.
    temp_dir = mkdtemp(prefix="ecooptimizer-")
    source_copy = Path(temp_dir) / source_dir.name
    target_file_copy = Path(str(target_file).replace(str(source_dir), str(source_copy), 1))

    shutil.copytree(source_dir, source_copy, ignore=shutil.ignore_patterns(".git*"))

    modified_files = []
    try:
        modified_files: list[Path] = refactorer_controller.run_refactorer(
            target_file_copy, source_copy, smell
        )
    except NotImplementedError:
        # Unimplemented refactorers are tolerated; continues with no modified files.
        print("Not implemented yet.")
    except Exception as e:
        print(f"An unexpected error occured: {e!s}")
        traceback.print_exc()
        # Clean up the temp copy before propagating the failure.
        shutil.rmtree(temp_dir, onerror=remove_readonly)
        raise RefactoringError(str(target_file), str(e)) from e

    final_emissions = measure_energy(target_file_copy)

    if not final_emissions:
        print("❌ Could not retrieve final emissions. Discarding refactoring.")

        CONFIG["refactorLogger"].error(
            "❌ Could not retrieve final emissions. Discarding refactoring."
        )

        shutil.rmtree(temp_dir, onerror=remove_readonly)
        raise RuntimeError("Could not retrieve final emissions.")

    # In production, only keep refactorings that measurably reduce emissions;
    # development mode skips this gate.
    if CONFIG["mode"] == "production" and final_emissions >= initial_emissions:
        CONFIG["refactorLogger"].info(f"πŸ“Š Final emissions: {final_emissions} kg CO2")
        CONFIG["refactorLogger"].info("⚠️ No measured energy savings. Discarding refactoring.")

        print("❌ Could not retrieve final emissions. Discarding refactoring.")

        shutil.rmtree(temp_dir, onerror=remove_readonly)
        raise EnergySavingsError(str(target_file), "Energy was not saved after refactoring.")

    CONFIG["refactorLogger"].info(
        f"βœ… Energy saved! Initial: {initial_emissions}, Final: {final_emissions}"
    )

    refactor_data = {
        "tempDir": temp_dir,
        "targetFile": {
            "original": str(target_file.resolve()),
            "refactored": str(target_file_copy.resolve()),
        },
        # NaN savings (possible with meter readings) are reported as None.
        "energySaved": initial_emissions - final_emissions
        if not math.isnan(initial_emissions - final_emissions)
        else None,
        "affectedFiles": [
            {
                # Map each copied path back to its original location.
                "original": str(file.resolve()).replace(
                    str(source_copy.resolve()), str(source_dir.resolve())
                ),
                "refactored": str(file.resolve()),
            }
            for file in modified_files
        ],
    }

    # Re-analyze the refactored copy so the client sees the remaining smells.
    updated_smells = analyzer_controller.run_analysis(target_file_copy)
    return refactor_data, updated_smells


def measure_energy(file: Path):
    # Runs the CodeCarbon meter on *file* and returns the recorded emissions,
    # or None when the measurement failed.
    energy_meter.measure_energy(file)
    return energy_meter.emissions


def clean_refactored_data(refactor_data: dict[str, Any]):
    """Ensures the refactored data is correctly structured and handles missing fields."""
    try:
        return RefactoredData(
            tempDir=refactor_data.get("tempDir", ""),
            targetFile=ChangedFile(
                original=refactor_data["targetFile"].get("original", ""),
                refactored=refactor_data["targetFile"].get("refactored", ""),
            ),
            energySaved=refactor_data.get("energySaved", None),
            affectedFiles=[
                ChangedFile(
                    original=file.get("original", ""),
                    refactored=file.get("refactored", ""),
                )
                for file in refactor_data.get("affectedFiles", [])
            ],
        )
    except KeyError as e:
        # Only "targetFile" is a hard requirement above; everything else defaults.
        CONFIG["refactorLogger"].error(f"❌ Missing expected key in refactored data: {e}")
        raise HTTPException(status_code=500, detail=f"Missing key: {e}") from e
class LogInit(BaseModel):
    # Directory in which the logging manager creates its log files.
    log_dir: str


@router.post("/logs/init")
def initialize_logs(log_init: LogInit):
    """Creates the LoggingManager and wires its loggers into the shared CONFIG."""
    try:
        loggingManager = LoggingManager(Path(log_init.log_dir), CONFIG["mode"] == "production")
        CONFIG["loggingManager"] = loggingManager
        CONFIG["detectLogger"] = loggingManager.loggers["detect"]
        CONFIG["refactorLogger"] = loggingManager.loggers["refactor"]

        return {"message": "Logging initialized succesfully."}
    except Exception as e:
        # NOTE(review): this is a plain HTTP POST route, so raising
        # WebSocketException here looks like it should be HTTPException — confirm.
        raise WebSocketException(code=500, reason=str(e)) from e


@router.websocket("/logs/main")
async def websocket_main_logs(websocket: WebSocket):
    # Streams the main log file over a websocket.
    await websocket_log_stream(websocket, CONFIG["loggingManager"].log_files["main"])


@router.websocket("/logs/detect")
async def websocket_detect_logs(websocket: WebSocket):
    # Streams the detection log file over a websocket.
    await websocket_log_stream(websocket, CONFIG["loggingManager"].log_files["detect"])


@router.websocket("/logs/refactor")
async def websocket_refactor_logs(websocket: WebSocket):
    # Streams the refactoring log file over a websocket.
    await websocket_log_stream(websocket, CONFIG["loggingManager"].log_files["refactor"])


async def listen_for_disconnect(websocket: WebSocket):
    """Listens for client disconnects."""
    try:
        while True:
            # receive() blocks until a client frame (or disconnect) arrives.
            await websocket.receive()

            if websocket.client_state == WebSocketState.DISCONNECTED:
                raise WebSocketDisconnect()
    except WebSocketDisconnect:
        print("WebSocket disconnected from client.")
        raise
    except Exception as e:
        print(f"Unexpected error in listener: {e}")


async def websocket_log_stream(websocket: WebSocket, log_file: Path):
    """Streams log file content via WebSocket.

    Tails *log_file* from its current end, sending each new line to the client
    and polling every 0.5s while idle. A background task watches for the
    client disconnecting so the tail loop can stop.
    """
    await websocket.accept()

    # Start background task to listen for disconnect
    listener_task = asyncio.create_task(listen_for_disconnect(websocket))

    try:
        with log_file.open(encoding="utf-8") as file:
            file.seek(0, 2)  # Start at file end

            # The listener task finishing means the client went away.
            while not listener_task.done():
                if websocket.application_state != WebSocketState.CONNECTED:
                    raise WebSocketDisconnect(reason="Connection closed")

                line = file.readline()
                if line:
                    await websocket.send_text(line)
                else:
                    await asyncio.sleep(0.5)
    except FileNotFoundError:
        await websocket.send_text("Error: Log file not found.")
    except WebSocketDisconnect as e:
        print(e.reason)
    except Exception as e:
        print(f"Unexpected error: {e}")
    finally:
        listener_task.cancel()
        if websocket.client_state != WebSocketState.DISCONNECTED:
            await websocket.close()
class Occurence(BaseModel):
    """One positional occurrence of a smell within a source file."""

    # 1-based start line; end positions may be absent for some AST nodes.
    line: int
    endLine: int | None
    column: int
    endColumn: int | None


class AdditionalInfo(BaseModel):
    """Optional smell-specific extra data; all fields default to None.

    Subclasses narrow the fields that are mandatory for their smell type.
    """

    innerLoopLine: Optional[int] = None
    concatTarget: Optional[str] = None
    repetitions: Optional[int] = None
    callString: Optional[str] = None


class CRCInfo(AdditionalInfo):
    """Extra data for cached-repeated-calls smells (fields become required)."""

    callString: str  # type: ignore
    repetitions: int  # type: ignore


class SCLInfo(AdditionalInfo):
    """Extra data for string-concat-in-loop smells (fields become required)."""

    innerLoopLine: int  # type: ignore
    concatTarget: str  # type: ignore
class Smell(BaseModel):
    """
    Represents a code smell detected in a source file, including its location, type, and related metadata.

    Attributes:
        confidence (str): The level of confidence for the smell detection (e.g., "high", "medium", "low").
        message (str): A descriptive message explaining the nature of the smell.
        messageId (str): A unique identifier for the specific message or warning related to the smell.
        module (str): The name of the module or component in which the smell is located.
        obj (str): The specific object (e.g., function, class) associated with the smell.
        path (str): The relative path to the source file from the project root.
        symbol (str): The symbol or code construct (e.g., variable, method) involved in the smell.
        type (str): The type or category of the smell (e.g., "complexity", "duplication").
        occurences (list[Occurence]): A list of individual occurences of a same smell, contains positional info.
        additionalInfo (AdditionalInfo): (Optional) Any custom information for a type of smell.
    """

    confidence: str
    message: str
    messageId: str
    module: str
    obj: str | None
    path: str
    symbol: str
    type: str
    occurences: list[Occurence]
    additionalInfo: Optional[AdditionalInfo] = None


class CRCSmell(Smell):
    """Cached-repeated-calls smell; requires CRCInfo extra data."""

    additionalInfo: CRCInfo  # type: ignore


class SCLSmell(Smell):
    """String-concat-in-loop smell; requires SCLInfo extra data."""

    additionalInfo: SCLInfo  # type: ignore


# These smell types carry no extra data, so they alias the base model.
LECSmell = Smell
LLESmell = Smell
LMCSmell = Smell
LPLSmell = Smell
UVASmell = Smell
MIMSmell = Smell
UGESmell = Smell
+ """ + + id: str + enabled: bool + analyzer_method: str + checker: Callable | None # type: ignore + refactorer: type[BaseRefactorer] # type: ignore # Refers to a class, not an instance + analyzer_options: dict[str, Any] # type: ignore diff --git a/src/ecooptimizer/exceptions.py b/src/ecooptimizer/exceptions.py new file mode 100644 index 00000000..298a5327 --- /dev/null +++ b/src/ecooptimizer/exceptions.py @@ -0,0 +1,25 @@ +import os +import stat + + +class RefactoringError(Exception): + """Exception raised for errors that occured during the refcatoring process. + + Attributes: + targetFile -- file being refactored + message -- explanation of the error + """ + + def __init__(self, targetFile: str, message: str) -> None: + self.targetFile = targetFile + super().__init__(message) + + +class EnergySavingsError(RefactoringError): + pass + + +def remove_readonly(func, path, _): # noqa: ANN001 + # "Clear the readonly bit and reattempt the removal" + os.chmod(path, stat.S_IWRITE) # noqa: PTH101 + func(path) diff --git a/src/measurement/measurement_utils.py b/src/ecooptimizer/measurements/__init__.py similarity index 100% rename from src/measurement/measurement_utils.py rename to src/ecooptimizer/measurements/__init__.py diff --git a/src/ecooptimizer/measurements/base_energy_meter.py b/src/ecooptimizer/measurements/base_energy_meter.py new file mode 100644 index 00000000..425b1fc0 --- /dev/null +++ b/src/ecooptimizer/measurements/base_energy_meter.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from pathlib import Path + + +class BaseEnergyMeter(ABC): + def __init__(self): + """ + Base class for energy meters to measure the emissions of a given file. + + :param file_path: Path to the file to measure energy consumption. + :param logger: Logger instance to handle log messages. + """ + self.emissions = None + + @abstractmethod + def measure_energy(self, file_path: Path): + """ + Abstract method to measure the energy consumption of the specified file. 
# Reconstructed: src/ecooptimizer/measurements/codecarbon_energy_meter.py
import logging
import os
from pathlib import Path
import sys
import subprocess
import pandas as pd
from tempfile import TemporaryDirectory
from codecarbon import EmissionsTracker

from .base_energy_meter import BaseEnergyMeter


class CodeCarbonEnergyMeter(BaseEnergyMeter):
    """Energy meter that runs a file under CodeCarbon and records emissions."""

    def __init__(self):
        """Initialize the meter with empty measurement state."""
        super().__init__()
        # Last row of the emissions CSV as a dict; None until a run succeeds.
        self.emissions_data = None

    def measure_energy(self, file_path: Path):
        """
        Measure the carbon emissions of executing *file_path* with CodeCarbon.

        Stores the total emissions in ``self.emissions`` and the parsed last
        CSV row in ``self.emissions_data`` (None when the run failed).

        Args:
            file_path: Path to the Python file to execute and measure.
        """
        logging.info(f"Starting CodeCarbon energy measurement on {file_path.name}")

        with TemporaryDirectory() as custom_temp_dir:
            # Redirect temp locations so CodeCarbon artifacts land in our
            # sandbox. Remember previous values so we can restore them —
            # otherwise the process is left with TEMP/TMPDIR pointing at a
            # directory that is deleted when this block exits.
            saved_env = {key: os.environ.get(key) for key in ("TEMP", "TMPDIR")}
            os.environ["TEMP"] = custom_temp_dir  # For Windows
            os.environ["TMPDIR"] = custom_temp_dir  # For Unix-based systems

            try:
                # TODO: Save to logger so doesn't print to console
                tracker = EmissionsTracker(
                    output_dir=custom_temp_dir,
                    allow_multiple_runs=True,
                    tracking_mode="process",
                    log_level="error",
                )  # type: ignore
                tracker.start()

                try:
                    subprocess.run(
                        [sys.executable, file_path], capture_output=True, text=True, check=True
                    )
                    logging.info("CodeCarbon measurement completed successfully.")
                except subprocess.CalledProcessError as e:
                    logging.error(f"Error executing file '{file_path}': {e}")
                finally:
                    # Always stop the tracker so partial emissions are recorded.
                    self.emissions = tracker.stop()
                    emissions_file = Path(custom_temp_dir) / "emissions.csv"

                    if emissions_file.exists():
                        self.emissions_data = self.extract_emissions_csv(emissions_file)
                    else:
                        logging.error(
                            "Emissions file was not created due to an error during execution."
                        )
                        self.emissions_data = None
            finally:
                # Restore the original environment (unset keys are removed).
                for key, value in saved_env.items():
                    if value is None:
                        os.environ.pop(key, None)
                    else:
                        os.environ[key] = value

    def extract_emissions_csv(self, csv_file_path: Path):
        """
        Extract emissions data from a CSV file generated by CodeCarbon.

        Args:
            csv_file_path: Path to the emissions CSV file.

        Returns:
            Dict containing the last row of emissions data, or None if the
            file is missing or unreadable.
        """
        if not csv_file_path.exists():
            logging.info(f"File '{csv_file_path}' does not exist.")
            return None
        try:
            df = pd.read_csv(str(csv_file_path))
            return df.to_dict(orient="records")[-1]
        except Exception as e:
            logging.info(f"Error reading file '{csv_file_path}': {e}")
            return None
# Reconstructed: src/ecooptimizer/refactorers/concrete/list_comp_any_all.py
import libcst as cst
from pathlib import Path
from libcst.metadata import PositionProvider

from ..base_refactorer import BaseRefactorer
from ...data_types.smell import UGESmell


class ListCompInAnyAllTransformer(cst.CSTTransformer):
    """Rewrites one `any([...])` / `all([...])` call on a target line so its
    list-comprehension argument becomes a generator expression."""

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, target_line: int, start_col: int, end_col: int):
        super().__init__()
        self.target_line = target_line
        # Column bounds are stored for future precision; matching is
        # currently line-based only.
        self.start_col = start_col
        self.end_col = end_col
        self.found = False  # set once a node has been transformed

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """
        Detects `any([...])` or `all([...])` calls and converts their list
        comprehension argument to a generator expression.
        """
        if self.found:
            return updated_node  # Avoid modifying multiple nodes in one pass

        # Check if the function is `any` or `all` with exactly one argument.
        is_any_or_all = (
            isinstance(original_node.func, cst.Name)
            and original_node.func.value in {"any", "all"}
        )
        if is_any_or_all and len(original_node.args) == 1:
            arg = original_node.args[0].value  # Extract the argument expression

            # Only transform when the argument is a list comprehension on the
            # line reported by the smell.
            if isinstance(arg, cst.ListComp):
                metadata = self.get_metadata(PositionProvider, original_node, None)
                if metadata and metadata.start.line == self.target_line:
                    self.found = True
                    generator = cst.GeneratorExp(
                        elt=arg.elt, for_in=arg.for_in, lpar=[], rpar=[]
                    )
                    return updated_node.with_changes(
                        args=[updated_node.args[0].with_changes(value=generator)]
                    )

        return updated_node


class UseAGeneratorRefactorer(BaseRefactorer[UGESmell]):
    """Replaces list comprehensions inside `any()` / `all()` calls with
    generator expressions (avoids materializing the intermediate list)."""

    def __init__(self):
        super().__init__()

    def refactor(
        self,
        target_file: Path,
        source_dir: Path,  # noqa: ARG002
        smell: UGESmell,
        output_file: Path,
        overwrite: bool = True,
    ):
        """
        Refactor an unnecessary list comprehension inside `any()` or `all()`
        by converting it to a generator expression.

        Args:
            target_file: File containing the smell.
            source_dir: Unused; kept for interface compatibility.
            smell: Smell whose first occurrence locates the call to rewrite.
            output_file: Destination when overwrite is False.
            overwrite: When True, write the result back to target_file.
        """
        occurrence = smell.occurences[0]

        # Parse the source with LibCST (metadata needed for positions).
        source_code = target_file.read_text()
        wrapper = cst.MetadataWrapper(cst.parse_module(source_code))

        transformer = ListCompInAnyAllTransformer(
            occurrence.line, occurrence.column, occurrence.endColumn
        )  # type: ignore
        modified_tree = wrapper.visit(transformer)

        if transformer.found:
            destination = target_file if overwrite else output_file
            destination.write_text(modified_tree.code)
            # Record the change, consistent with the other refactorers
            # (LLE/LMC append to modified_files on success).
            self.modified_files.append(target_file)
+ """ + + def __init__(self): + super().__init__() + self.dict_name: set[str] = set() + self.access_patterns: set[DictAccess] = set() + self.min_value = float("inf") + self.dict_assignment: Optional[dict[str, Any]] = None + self.initial_parsing = True + + def refactor( + self, + target_file: Path, + source_dir: Path, + smell: LECSmell, + output_file: Path, # noqa: ARG002 + overwrite: bool = True, # noqa: ARG002 + ) -> None: + """Main refactoring method that processes the target file and related files.""" + self.target_file = target_file + line_number = smell.occurences[0].line + + tree = ast.parse(target_file.read_text()) + self._find_dict_names(tree, line_number) + + # Abort if dictionary access is too shallow + self.traverse_and_process(source_dir) + if self.min_value <= 1: + return + + self.initial_parsing = False + self.traverse_and_process(source_dir) + + def _find_dict_names(self, tree: ast.AST, line_number: int) -> None: + """Extract dictionary names from the AST at the given line number.""" + for node in ast.walk(tree): + if not ( + isinstance(node, ast.Subscript) + and hasattr(node, "lineno") + and node.lineno == line_number + ): + continue + + if isinstance(node.value, ast.Name): + self.dict_name.add(node.value.id) + else: + dict_name = self._extract_dict_name(node.value) + if dict_name: + self.dict_name.add(dict_name) + self.dict_name.add(dict_name.split(".")[-1]) + + def _extract_dict_name(self, node: ast.AST) -> Optional[str]: + """Extract dictionary name from attribute access chains.""" + while isinstance(node, ast.Subscript): + node = node.value + + if isinstance(node, ast.Attribute): + return f"{node.value.id}.{node.attr}" + return None + + def _process_file(self, file: Path): + tree = ast.parse(file.read_text()) + if self.initial_parsing: + self._find_access_pattern_in_file(tree, file) + else: + self.find_dict_assignment_in_file(tree) + if self._refactor_all_in_file(file): + return True + + return False + + # finds all access patterns in the file + 
def _find_access_pattern_in_file(self, tree: ast.AST, path: Path): + offset = set() + for node in ast.walk(tree): + if isinstance(node, ast.Subscript): # Check for dictionary access (Subscript) + dict_name, full_access, line_number, col_offset = self.extract_full_dict_access( + node + ) + + if (line_number, col_offset) in offset: + continue + offset.add((line_number, col_offset)) + + if dict_name.split(".")[-1] in self.dict_name: + nesting_level = self._count_nested_subscripts(node) + access = DictAccess( + dict_name, full_access, nesting_level, line_number, col_offset, path, node + ) + self.access_patterns.add(access) + print(self.access_patterns) + self.min_value = min(self.min_value, nesting_level) + + def extract_full_dict_access(self, node: ast.Subscript): + """Extracts the full dictionary access chain as a string.""" + access_chain = [] + curr = node + # Traverse nested subscripts to build access path + while isinstance(curr, ast.Subscript): + if isinstance(curr.slice, ast.Constant): # Python 3.8+ + access_chain.append(f"['{curr.slice.value}']") + curr = curr.value # Move to parent node + + # Get the dictionary root (can be a variable or an attribute) + if isinstance(curr, ast.Name): + dict_name = curr.id # Simple variable (e.g., "long_chain") + elif isinstance(curr, ast.Attribute) and isinstance(curr.value, ast.Name): + dict_name = f"{curr.value.id}.{curr.attr}" # Attribute access (e.g., "self.long_chain") + else: + dict_name = "UNKNOWN" + + full_access = f"{dict_name}{''.join(reversed(access_chain))}" + + return dict_name, full_access, curr.lineno, curr.col_offset + + def _count_nested_subscripts(self, node: ast.Subscript): + """ + Counts how many times a dictionary is accessed (nested Subscript nodes). 
+ """ + level = 0 + curr = node + while isinstance(curr, ast.Subscript): + curr = curr.value # Move up the AST + level += 1 + return level + + def find_dict_assignment_in_file(self, tree: ast.AST): + """find the dictionary assignment from AST based on the dict name""" + + class DictVisitor(ast.NodeVisitor): + def visit_Assign(self_, node: ast.Assign): + if isinstance(node.value, ast.Dict) and len(node.targets) == 1: + # dictionary is a varibale + if ( + isinstance(node.targets[0], ast.Name) + and node.targets[0].id in self.dict_name + ): + dict_value = self.extract_dict_literal(node.value) + flattened_version = self.flatten_dict(dict_value) # type: ignore + self.dict_assignment = flattened_version + + # dictionary is an attribute + elif ( + isinstance(node.targets[0], ast.Attribute) + and node.targets[0].attr in self.dict_name + ): + dict_value = self.extract_dict_literal(node.value) + self.dict_assignment = self.flatten_dict(dict_value) # type: ignore + self_.generic_visit(node) + + DictVisitor().visit(tree) + + def extract_dict_literal(self, node: ast.AST): + """Convert AST dict literal to Python dict.""" + if isinstance(node, ast.Dict): + return { + self.extract_dict_literal(k) + if isinstance(k, ast.AST) + else k: self.extract_dict_literal(v) if isinstance(v, ast.AST) else v + for k, v in zip(node.keys, node.values) + } + elif isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Name): + return node.id + return node + + def flatten_dict( + self, d: dict[str, Any], depth: int = 0, parent_key: str = "" + ) -> dict[str, Any]: + """Recursively flatten a nested dictionary.""" + + if depth >= self.min_value - 1: + # At max_depth, we return the current dictionary as flattened key-value pairs + items = {} + for k, v in d.items(): + new_key = f"{parent_key}_{k}" if parent_key else k + items[new_key] = v + return items + + items = {} + for k, v in d.items(): + new_key = f"{parent_key}_{k}" if parent_key else k + + if isinstance(v, dict): + # 
Recursively flatten the dictionary, increasing the depth + items.update(self.flatten_dict(v, depth + 1, new_key)) + else: + # If it's not a dictionary, just add it to the result + items[new_key] = v + + return items + + def generate_flattened_access(self, access_chain: list[str]) -> str: + """Generate flattened dictionary key only until given min_value.""" + + joined = "_".join(k.strip("'\"") for k in access_chain[: self.min_value]) + if not joined.endswith("']") or not joined.endswith('"]'): # Corrected to check for "']" + joined += "']" + remaining = access_chain[self.min_value :] # Keep the rest unchanged + + rest = "".join(f"[{key}]" for key in remaining) + + return f"{joined}" + rest + + def _refactor_all_in_file(self, file_path: Path): + """Refactor dictionary access patterns in a single file.""" + # Skip if no access patterns found + if not any(access.path == file_path for access in self.access_patterns): + return False + + source_code = file_path.read_text() + lines = source_code.split("\n") + line_modifications = self._collect_line_modifications(file_path) + + refactored_lines = self._apply_modifications(lines, line_modifications) + refactored_lines = self._update_dict_assignment(refactored_lines) + + # Write changes back to file + file_path.write_text("\n".join(refactored_lines)) + + return True + + def _collect_line_modifications(self, file_path: Path) -> dict[int, list[tuple[int, str, str]]]: + """Collect all modifications needed for each line.""" + modifications: dict[int, list[tuple[int, str, str]]] = {} + + for access in sorted(self.access_patterns, key=lambda a: (a.line_number, a.col_offset)): + if access.path != file_path: + continue + + access_chain = access.full_access.split("][") + for i in range(len(access_chain)): + access_chain[i] = access_chain[i].replace("]", "") + new_access = self.generate_flattened_access(access_chain) + + if access.line_number not in modifications: + modifications[access.line_number] = [] + 
modifications[access.line_number].append( + (access.col_offset, access.full_access, new_access) + ) + + return modifications + + def _apply_modifications( + self, lines: list[str], modifications: dict[int, list[tuple[int, str, str]]] + ) -> list[str]: + """Apply collected modifications to each line.""" + refactored_lines = [] + for line_num, original_line in enumerate(lines, start=1): + if line_num in modifications: + # Sort modifications by column offset (reverse to replace from right to left) + mods = sorted(modifications[line_num], key=lambda x: x[0], reverse=True) + modified_line = original_line + # print("this si the og line: " + modified_line) + + for col_offset, old_access, new_access in mods: + end_idx = col_offset + len(old_access) + # Replace specific occurrence using slicing + modified_line = ( + modified_line[:col_offset] + new_access + modified_line[end_idx:] + ) + # print(modified_line) + + refactored_lines.append(modified_line) + else: + # No modification, add original line + refactored_lines.append(original_line) + + return refactored_lines + + def _update_dict_assignment(self, refactored_lines: list[str]) -> None: + """Update dictionary assignment to be the new flattened dictionary.""" + dictionary_assignment_name = self.dict_name + for i, line in enumerate(refactored_lines): + match = next( + ( + name + for name in dictionary_assignment_name + if re.match(rf"^\s*(?:\w+\.)*{re.escape(name)}\s*=", line) + ), + None, + ) + + if match: + # Preserve indentation and the `=` + indent, prefix, _ = re.split(r"(=)", line, maxsplit=1) + + # Convert dict to a properly formatted string + dict_str = json.dumps(self.dict_assignment, separators=(",", ": ")) + # Update the line with the new flattened dictionary + refactored_lines[i] = f"{indent}{prefix} {dict_str}" + + # Remove the following lines of the original nested dictionary, + # leaving only one empty line after them + j = i + 1 + while j < len(refactored_lines) and ( + 
refactored_lines[j].strip().startswith('"') + or refactored_lines[j].strip().startswith("}") + ): + refactored_lines[j] = "Remove this line" # Mark for removal + j += 1 + break + + refactored_lines = [line for line in refactored_lines if line.strip() != "Remove this line"] + + return refactored_lines diff --git a/src/ecooptimizer/refactorers/concrete/long_lambda_function.py b/src/ecooptimizer/refactorers/concrete/long_lambda_function.py new file mode 100644 index 00000000..76c5e6bc --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/long_lambda_function.py @@ -0,0 +1,153 @@ +from pathlib import Path +import re +from ..base_refactorer import BaseRefactorer +from ...data_types.smell import LLESmell + + +class LongLambdaFunctionRefactorer(BaseRefactorer[LLESmell]): + """ + Refactorer that targets long lambda functions by converting them into normal functions. + """ + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def truncate_at_top_level_comma(body: str) -> str: + """ + Truncate the lambda body at the first top-level comma, ignoring commas + within nested parentheses, brackets, or braces. + """ + truncated_body = [] + open_parens = 0 + + for char in body: + if char in "([{": + open_parens += 1 + elif char in ")]}": + open_parens -= 1 + elif char == "," and open_parens == 0: + # Stop at the first top-level comma + break + + truncated_body.append(char) + + return "".join(truncated_body).strip() + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: LLESmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactor long lambda functions by converting them into normal functions + and writing the refactored code to a new file. 
+ """ + # Extract details from smell + line_number = smell.occurences[0].line + + # Read the original file + content = target_file.read_text(encoding="utf-8") + lines = content.splitlines(keepends=True) + + # Capture the entire logical line containing the lambda + current_line = line_number - 1 + lambda_lines = [lines[current_line].rstrip()] + + # Check if lambda is wrapped in parentheses + has_parentheses = lambda_lines[0].strip().startswith("(") + + # Find continuation lines only if needed + if has_parentheses: + while current_line < len(lines) - 1 and not lambda_lines[ + -1 + ].strip().endswith(")"): + current_line += 1 + lambda_lines.append(lines[current_line].rstrip()) + else: + # Handle single-line lambda + lambda_lines = [lines[current_line].rstrip()] + + full_lambda_line = " ".join(lambda_lines).strip() + + # Remove surrounding parentheses if present + if has_parentheses: + full_lambda_line = re.sub(r"^\((.*)\)$", r"\1", full_lambda_line) + + # Extract leading whitespace for correct indentation + original_indent = re.match(r"^\s*", lambda_lines[0]).group() # type: ignore + + # Use different regex based on whether the lambda line starts with a parenthesis + if has_parentheses: + lambda_match = re.search( + r"lambda\s+([\w, ]+):\s+(.+?)(?=\s*\))", full_lambda_line + ) + else: + lambda_match = re.search(r"lambda\s+([\w, ]+):\s+(.+)", full_lambda_line) + + if not lambda_match: + return + + # Extract arguments and body of the lambda + lambda_args = lambda_match.group(1).strip() + lambda_body_before = lambda_match.group(2).strip() + lambda_body_before = LongLambdaFunctionRefactorer.truncate_at_top_level_comma( + lambda_body_before + ) + + # Ensure that the lambda body does not contain extra trailing characters + # Remove any trailing commas or mismatched closing brackets + lambda_body = re.sub(r",\s*\)$", "", lambda_body_before).strip() + + lambda_body_no_extra_space = re.sub(r"\s{2,}", " ", lambda_body) + # Generate a unique function name + function_name = 
f"converted_lambda_{line_number}" + + # Find the start of the block containing the lambda + original_indent_len = len(original_indent) + block_start = line_number - 1 + while block_start > 0: + prev_line = lines[block_start - 1].rstrip() + prev_indent = len(re.match(r"^\s*", prev_line).group()) # type: ignore + if prev_line.endswith(":") and prev_indent < original_indent_len: + break + block_start -= 1 + + # Get proper block indentation + block_indentation = re.match(r"^\s*", lines[block_start]).group() # type: ignore + function_indent = block_indentation + body_indent = function_indent + " " * 4 + + # Create properly indented function definition + function_def = ( + f"{function_indent}def {function_name}({lambda_args}):\n" + f"{body_indent}result = {lambda_body_no_extra_space}\n" + f"{body_indent}return result\n\n" + ) + + # Prepare refactored line with original indentation + replacement_line = full_lambda_line.replace( + f"lambda {lambda_args}: {lambda_body}", function_name + ) + refactored_line = f"{original_indent}{replacement_line.strip()}" + + # Split multi-line function definition into individual lines + function_lines = function_def.splitlines(keepends=True) + + # Replace the lambda line with the refactored line in place + lines[current_line] = f"{refactored_line}\n" + + # Insert the new function definition immediately at the beginning of the block + lines.insert(block_start, "".join(function_lines)) + + # Write changes + new_content = "".join(lines) + if overwrite: + target_file.write_text(new_content, encoding="utf-8") + else: + output_file.write_text(new_content, encoding="utf-8") + + self.modified_files.append(target_file) diff --git a/src/ecooptimizer/refactorers/concrete/long_message_chain.py b/src/ecooptimizer/refactorers/concrete/long_message_chain.py new file mode 100644 index 00000000..5f7f9738 --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/long_message_chain.py @@ -0,0 +1,137 @@ +from pathlib import Path +import re +from 
# Reconstructed: src/ecooptimizer/refactorers/concrete/long_message_chain.py
from pathlib import Path
import re
from ..base_refactorer import BaseRefactorer
from ...data_types.smell import LMCSmell


class LongMessageChainRefactorer(BaseRefactorer[LMCSmell]):
    """
    Refactorer that targets long method chains to improve performance.
    """

    def __init__(self) -> None:
        super().__init__()

    def refactor(
        self,
        target_file: Path,
        source_dir: Path,  # noqa: ARG002
        smell: LMCSmell,
        output_file: Path,
        overwrite: bool = True,
    ):
        """
        Refactor long message chains by breaking them into separate statements
        assigned to intermediate variables, then write the refactored code out.

        Args:
            target_file: File containing the smell.
            source_dir: Unused; kept for interface compatibility.
            smell: Smell whose first occurrence locates the chained call.
            output_file: Destination when overwrite is False.
            overwrite: When True, write the result back to target_file.
        """
        # Extract details from smell
        line_number = smell.occurences[0].line

        # Read file content, preserving line endings
        content = target_file.read_text(encoding="utf-8")
        lines = content.splitlines(keepends=True)

        # Identify the line with the long method chain
        line_with_chain = lines[line_number - 1].rstrip()

        # Extract leading whitespace for correct indentation
        leading_whitespace = re.match(r"^\s*", line_with_chain).group()  # type: ignore

        # Check if the line contains an f-string
        # NOTE(review): only double-quoted f-strings are matched — confirm
        # whether single-quoted f-strings need handling.
        f_string_pattern = r"f\".*?\""
        if re.search(f_string_pattern, line_with_chain):
            # Determine if original was print or assignment
            is_print = line_with_chain.startswith("print(")
            original_var = None if is_print else line_with_chain.split("=", 1)[0].strip()

            # Extract f-string and the method chain following it
            f_string_content = re.search(f_string_pattern, line_with_chain).group()  # type: ignore
            remaining_chain = line_with_chain.split(f_string_content, 1)[-1].lstrip(".")

            # Split on dots that are not inside call parentheses
            method_calls = re.split(r"\.(?![^()]*\))", remaining_chain.strip())
            refactored_lines = []

            # Initial f-string assignment
            refactored_lines.append(f"{leading_whitespace}intermediate_0 = {f_string_content}")

            # Process method calls, one intermediate variable per step
            for i, method in enumerate(method_calls, start=1):
                method = method.strip()
                if not method:
                    continue

                if i < len(method_calls):
                    refactored_lines.append(
                        f"{leading_whitespace}intermediate_{i} = " f"intermediate_{i-1}.{method}"
                    )
                else:
                    # Final assignment using original variable name (or print)
                    if is_print:
                        refactored_lines.append(
                            f"{leading_whitespace}print(intermediate_{i-1}.{method})"
                        )
                    else:
                        refactored_lines.append(
                            f"{leading_whitespace}{original_var} = " f"intermediate_{i-1}.{method}"
                        )

            lines[line_number - 1] = "\n".join(refactored_lines) + "\n"

        else:
            # Handle non-f-string chains
            original_has_print = "print(" in line_with_chain
            chain_content = re.sub(r"^\s*print\((.*)\)\s*$", r"\1", line_with_chain)

            # Extract RHS if assignment exists
            if "=" in chain_content:
                chain_content = chain_content.split("=", 1)[1].strip()

            # Split chain after closing parentheses
            method_calls = re.split(r"(?<=\))\.", chain_content)

            if len(method_calls) > 1:
                refactored_lines = []
                base_var = method_calls[0].strip()
                refactored_lines.append(f"{leading_whitespace}intermediate_0 = {base_var}")

                # Process subsequent method calls
                for i, method in enumerate(method_calls[1:], start=1):
                    method = method.strip().lstrip(".")
                    if not method:
                        continue

                    if i < len(method_calls) - 1:
                        refactored_lines.append(
                            f"{leading_whitespace}intermediate_{i} = "
                            f"intermediate_{i-1}.{method}"
                        )
                    else:
                        # Preserve original assignment/print structure
                        if original_has_print:
                            refactored_lines.append(
                                f"{leading_whitespace}print(intermediate_{i-1}.{method})"
                            )
                        else:
                            original_assignment = line_with_chain.split("=", 1)[0].strip()
                            refactored_lines.append(
                                f"{leading_whitespace}{original_assignment} = "
                                f"intermediate_{i-1}.{method}"
                            )

                lines[line_number - 1] = "\n".join(refactored_lines) + "\n"

        # Join lines and write to the appropriate file based on overwrite flag
        new_content = "".join(lines)
        if overwrite:
            target_file.write_text(new_content, encoding="utf-8")
        else:
            output_file.write_text(new_content, encoding="utf-8")

        self.modified_files.append(target_file)
encoding="utf-8") + else: + output_file.write_text(new_content, encoding="utf-8") + + self.modified_files.append(target_file) diff --git a/src/ecooptimizer/refactorers/concrete/long_parameter_list.py b/src/ecooptimizer/refactorers/concrete/long_parameter_list.py new file mode 100644 index 00000000..4b1205d8 --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/long_parameter_list.py @@ -0,0 +1,635 @@ +import libcst as cst +import libcst.matchers as m +from libcst.metadata import PositionProvider, MetadataWrapper, ParentNodeProvider +from pathlib import Path +from typing import Optional +from collections.abc import Mapping + +from ..multi_file_refactorer import MultiFileRefactorer +from ...data_types.smell import LPLSmell + + +class FunctionCallVisitor(cst.CSTVisitor): + def __init__(self, function_name: str, class_name: str, is_constructor: bool): + self.function_name = function_name + self.is_constructor = is_constructor # whether or not given function call is a constructor + self.class_name = ( + class_name # name of class being instantiated if function is a constructor + ) + self.found = False + + def visit_Call(self, node: cst.Call): + """Check if the function/class constructor is called.""" + # handle class constructor call + if self.is_constructor and m.matches(node.func, m.Name(self.class_name)): + self.found = True + + # handle standalone function calls + elif m.matches(node.func, m.Name(self.function_name)): + self.found = True + + # handle method calss + elif m.matches(node.func, m.Attribute(attr=m.Name(self.function_name))): + self.found = True + + +class ParameterAnalyzer: + @staticmethod + def get_used_parameters(function_node: cst.FunctionDef, params: list[str]) -> list[str]: + """ + Identifies parameters that actually are used within the function/method body using CST analysis + """ + + # visitor class to collect variable names used in the function body + class UsedParamVisitor(cst.CSTVisitor): + def __init__(self): + self.used_names = set() + + 
def visit_Name(self, node: cst.Name) -> None: + self.used_names.add(node.value) + + # traverse the function body to collect used variable names + visitor = UsedParamVisitor() + function_node.body.visit(visitor) + + return [name for name in params if name in visitor.used_names] + + @staticmethod + def get_parameters_with_default_value(params: list[cst.Param]) -> dict[str, cst.Arg]: + """ + Given a list of function parameters and their default values, maps parameter names to their default values + """ + param_defaults = {} + + for param in params: + if param.default is not None: # check if the parameter has a default value + param_defaults[param.name.value] = param.default + + return param_defaults + + @staticmethod + def classify_parameters(params: list[str]) -> dict[str, list[str]]: + """ + Classifies parameters into 'data' and 'config' groups based on naming conventions + """ + data_params, config_params = [], [] + data_keywords = {"data", "input", "output", "result", "record", "item"} + config_keywords = {"config", "setting", "option", "env", "parameter", "path"} + + for param in params: + param_lower = param.lower() + if any(keyword in param_lower for keyword in data_keywords): + data_params.append(param) + elif any(keyword in param_lower for keyword in config_keywords): + config_params.append(param) + else: + data_params.append(param) + return {"data_params": data_params, "config_params": config_params} + + +class ParameterEncapsulator: + @staticmethod + def encapsulate_parameters( + classified_params: dict[str, list[str]], + default_value_params: dict[str, cst.Arg], + classified_param_names: tuple[str, str], + ) -> list[cst.ClassDef]: + """ + Generates CST class definitions for encapsulating parameter objects. 
+ """ + data_params, config_params = ( + classified_params["data_params"], + classified_params["config_params"], + ) + class_nodes = [] + + data_class_name, config_class_name = classified_param_names + + if data_params: + data_param_class = ParameterEncapsulator.create_parameter_object_class( + data_params, default_value_params, data_class_name + ) + class_nodes.append(data_param_class) + + if config_params: + config_param_class = ParameterEncapsulator.create_parameter_object_class( + config_params, default_value_params, config_class_name + ) + class_nodes.append(config_param_class) + + return class_nodes + + @staticmethod + def create_parameter_object_class( + param_names: list[str], + default_value_params: dict[str, cst.Arg], + class_name: str = "ParamsObject", + ) -> cst.ClassDef: + """ + Creates a CST class definition for encapsulating related parameters. + """ + # create constructor parameters + constructor_params = [cst.Param(name=cst.Name("self"))] + assignments = [] + + for param in param_names: + default_value = default_value_params.get(param, None) + + param_cst = cst.Param( + name=cst.Name(param), + default=default_value, # set default value if available + ) + constructor_params.append(param_cst) + + assignment = cst.SimpleStatementLine( + [ + cst.Assign( + targets=[ + cst.AssignTarget( + cst.Attribute(value=cst.Name("self"), attr=cst.Name(param)) + ) + ], + value=cst.Name(param), + ) + ] + ) + assignments.append(assignment) + + constructor = cst.FunctionDef( + name=cst.Name("__init__"), + params=cst.Parameters(params=constructor_params), + body=cst.IndentedBlock(body=assignments), + ) + + # create class definition + return cst.ClassDef( + name=cst.Name(class_name), + body=cst.IndentedBlock(body=[constructor]), + ) + + +class FunctionCallUpdater: + @staticmethod + def get_method_type(func_node: cst.FunctionDef) -> str: + """ + Determines whether a function is an instance method, class method, or static method + """ + # check for @staticmethod or 
@classmethod decorators + for decorator in func_node.decorators: + if isinstance(decorator.decorator, cst.Name): + if decorator.decorator.value == "staticmethod": + return "static method" + if decorator.decorator.value == "classmethod": + return "class method" + + # check the first parameter name + if func_node.params.params: + first_param = func_node.params.params[0].name.value + if first_param == "self": + return "instance method" + if first_param == "cls": + return "class method" + + return "unknown method type" + + @staticmethod + def remove_unused_params( + function_node: cst.FunctionDef, + used_params: list[str], + default_value_params: dict[str, cst.Arg], + ) -> cst.FunctionDef: + """ + Removes unused parameters from the function signature while preserving self/cls if applicable. + Ensures there is no trailing comma when removing the last parameter. + """ + method_type = FunctionCallUpdater.get_method_type(function_node) + + updated_params = [] + updated_defaults = [] + + # preserve self/cls if it's an instance or class method + if function_node.params.params and method_type in {"instance method", "class method"}: + updated_params.append(function_node.params.params[0]) + + # remove unused parameters, keeping only those that are used + for param in function_node.params.params: + if param.name.value in used_params: + updated_params.append(param) + if param.name.value in default_value_params: + updated_defaults.append(default_value_params[param.name.value]) + + # ensure that the last parameter does not leave a trailing comma + updated_params = [p.with_changes(comma=cst.MaybeSentinel.DEFAULT) for p in updated_params] + + return function_node.with_changes( + params=function_node.params.with_changes(params=updated_params) + ) + + @staticmethod + def update_function_signature( + function_node: cst.FunctionDef, classified_params: dict[str, list[str]] + ) -> cst.FunctionDef: + """ + Updates the function signature to use encapsulated parameter objects + """ + 
data_params, config_params = ( + classified_params["data_params"], + classified_params["config_params"], + ) + + method_type = FunctionCallUpdater.get_method_type(function_node) + new_params = [] + + # preserve self/cls if it's a method + if function_node.params.params and method_type in {"instance method", "class method"}: + new_params.append(function_node.params.params[0]) + + # add encapsulated objects as new parameters + if data_params: + new_params.append(cst.Param(name=cst.Name("data_params"))) + if config_params: + new_params.append(cst.Param(name=cst.Name("config_params"))) + + return function_node.with_changes( + params=function_node.params.with_changes(params=new_params) + ) + + @staticmethod + def update_parameter_usages( + function_node: cst.FunctionDef, classified_params: dict[str, list[str]] + ): + """ + Updates the function body to use encapsulated parameter objects. + """ + + class ParameterUsageTransformer(cst.CSTTransformer): + def __init__(self, classified_params: dict[str, list[str]]): + self.param_to_group = {} + + # flatten classified_params to map each param to its group (dataParams or configParams) + for group, params in classified_params.items(): + for param in params: + self.param_to_group[param] = group + + def leave_Assign( + self, + original_node: cst.Assign, # noqa: ARG002 + updated_node: cst.Assign, + ) -> cst.Assign: + """ + Transform only right-hand side references to parameters that need to be updated. + Ensure left-hand side (self attributes) remain unchanged. 
+ """ + if not isinstance(updated_node.value, cst.Name): + return updated_node + + var_name = updated_node.value.value + + if var_name in self.param_to_group: + new_value = cst.Attribute( + value=cst.Name(self.param_to_group[var_name]), attr=cst.Name(var_name) + ) + return updated_node.with_changes(value=new_value) + + return updated_node + + # wrap CST node in a MetadataWrapper to enable metadata analysis + transformer = ParameterUsageTransformer(classified_params) + return function_node.visit(transformer) + + @staticmethod + def get_enclosing_class_name( + tree: cst.Module, # noqa: ARG004 + init_node: cst.FunctionDef, + parent_metadata: Mapping[cst.CSTNode, cst.CSTNode], + ) -> Optional[str]: + """ + Finds the class name enclosing the given __init__ function node. + """ + # wrapper = MetadataWrapper(tree) + current_node = init_node + while current_node in parent_metadata: + parent = parent_metadata[current_node] + if isinstance(parent, cst.ClassDef): + return parent.name.value + current_node = parent + return None + + @staticmethod + def update_function_calls( + tree: cst.Module, + function_node: cst.FunctionDef, + used_params: list[str], + classified_params: dict[str, list[str]], + classified_param_names: tuple[str, str], + enclosing_class_name: str, + ) -> cst.Module: + """ + Updates all calls to a given function in the provided CST tree to reflect new encapsulated parameters + :param tree: CST tree of the code. + :param function_node: CST node of the function to update calls for. + :param params: A dictionary containing 'data' and 'config' parameters. 
+ :return: The updated CST tree + """ + param_to_group = {} + + for group_name, params in zip(classified_param_names, classified_params.values()): + for param in params: + param_to_group[param] = group_name + + function_name = function_node.name.value + if function_name == "__init__": + function_name = enclosing_class_name + + class FunctionCallTransformer(cst.CSTTransformer): + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002 + """Transforms function calls to use grouped parameters.""" + # Handle both standalone function calls and instance method calls + if not isinstance(updated_node.func, (cst.Name, cst.Attribute)): + return updated_node # Ignore other calls that are not functions/methods + + # Extract the function/method name + func_name = ( + updated_node.func.attr.value + if isinstance(updated_node.func, cst.Attribute) + else updated_node.func.value + ) + + # If the function/method being called is not the one we're refactoring, skip it + if func_name != function_name: + return updated_node + + positional_args = [] + keyword_args = {} + + # Separate positional and keyword arguments + for arg in updated_node.args: + if arg.keyword is None: + positional_args.append(arg.value) + else: + keyword_args[arg.keyword.value] = arg.value + + # Group arguments based on classified_params + grouped_args = {group: [] for group in classified_param_names} + + # Process positional arguments + param_index = 0 + for param in used_params: + if param_index < len(positional_args): + grouped_args[param_to_group[param]].append( + cst.Arg(value=positional_args[param_index]) + ) + param_index += 1 + + # Process keyword arguments + for kw, value in keyword_args.items(): + if kw in param_to_group: + grouped_args[param_to_group[kw]].append( + cst.Arg(value=value, keyword=cst.Name(kw)) + ) + + # Construct new grouped arguments + new_args = [ + cst.Arg( + value=cst.Call(func=cst.Name(group_name), args=grouped_args[group_name]) + ) + for 
group_name in classified_param_names + if grouped_args[group_name] # Skip empty groups + ] + + return updated_node.with_changes(args=new_args) + + transformer = FunctionCallTransformer() + return tree.visit(transformer) + + +class ClassInserter(cst.CSTTransformer): + def __init__(self, class_nodes: list[cst.ClassDef]): + self.class_nodes = class_nodes + self.insert_index = None + + def visit_Module(self, node: cst.Module) -> None: + """ + Identify the first function definition in the module. + """ + for i, statement in enumerate(node.body): + if isinstance(statement, cst.FunctionDef): + self.insert_index = i + break + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + """ + Insert the generated class definitions before the first function definition. + """ + if self.insert_index is None: + # if no function is found, append the class nodes at the beginning + new_body = list(self.class_nodes) + list(updated_node.body) + else: + # insert class nodes before the first function + new_body = ( + list(updated_node.body[: self.insert_index]) + + list(self.class_nodes) + + list(updated_node.body[self.insert_index :]) + ) + + return updated_node.with_changes(body=new_body) + + +class FunctionFinder(cst.CSTVisitor): + METADATA_DEPENDENCIES = (PositionProvider,) + + def __init__(self, position_metadata, target_line): # noqa: ANN001 + self.position_metadata = position_metadata + self.target_line = target_line + self.function_node = None + + def visit_FunctionDef(self, node: cst.FunctionDef): + """Check if the function's starting line matches the target.""" + pos = self.position_metadata.get(node) + if pos and pos.start.line == self.target_line: + self.function_node = node # Store the function node + + +class LongParameterListRefactorer(MultiFileRefactorer[LPLSmell]): + def __init__(self): + super().__init__() + self.parameter_analyzer = ParameterAnalyzer() + self.parameter_encapsulator = ParameterEncapsulator() + 
self.function_updater = FunctionCallUpdater() + self.function_node: Optional[cst.FunctionDef] = ( + None # AST node of definition of function that needs to be refactored + ) + self.used_params: None # list of unclassified used params + self.classified_params = None + self.classified_param_names = None + self.classified_param_nodes = [] + self.enclosing_class_name: Optional[str] = None + self.is_constructor = False + + def refactor( + self, + target_file: Path, + source_dir: Path, + smell: LPLSmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactors function/method with more than 6 parameters by encapsulating those with related names and removing those that are unused + """ + # maximum limit on number of parameters beyond which the code smell is configured to be detected(see analyzers_config.py) + max_param_limit = 6 + self.target_file = target_file + + with target_file.open() as f: + source_code = f.read() + + tree = cst.parse_module(source_code) + wrapper = MetadataWrapper(tree) + position_metadata = wrapper.resolve(PositionProvider) + parent_metadata = wrapper.resolve(ParentNodeProvider) + target_line = smell.occurences[0].line + + visitor = FunctionFinder(position_metadata, target_line) + wrapper.visit(visitor) # Traverses the CST tree + + if visitor.function_node: + self.function_node = visitor.function_node + + self.is_constructor = self.function_node.name.value == "__init__" + if self.is_constructor: + self.enclosing_class_name = FunctionCallUpdater.get_enclosing_class_name( + tree, self.function_node, parent_metadata + ) + param_names = [ + param.name.value + for param in self.function_node.params.params + if param.name.value != "self" + ] + param_nodes = [ + param for param in self.function_node.params.params if param.name.value != "self" + ] + # params that have default value assigned in function definition, stored as a dict of param name to default value + default_value_params = self.parameter_analyzer.get_parameters_with_default_value( 
+ param_nodes + ) + + if len(param_nodes) > max_param_limit: + # need to identify used parameters so unused ones can be removed + self.used_params = self.parameter_analyzer.get_used_parameters( + self.function_node, param_names + ) + + if len(self.used_params) > max_param_limit: + # classify used params into data and config types and store the results in a dictionary, if number of used params is beyond the configured limit + self.classified_params = self.parameter_analyzer.classify_parameters( + self.used_params + ) + self.classified_param_names = self._generate_unique_param_class_names( + target_line + ) + # add class defitions for data and config encapsulations to the tree + self.classified_param_nodes = ( + self.parameter_encapsulator.encapsulate_parameters( + self.classified_params, + default_value_params, + self.classified_param_names, + ) + ) + + # insert class definitions and update function calls + tree = tree.visit(ClassInserter(self.classified_param_nodes)) + # update calls to the function + tree = self.function_updater.update_function_calls( + tree, + self.function_node, + self.used_params, + self.classified_params, + self.classified_param_names, + self.enclosing_class_name, + ) + # next updaate function signature and parameter usages within function body + updated_function_node = self.function_updater.update_function_signature( + self.function_node, self.classified_params + ) + updated_function_node = self.function_updater.update_parameter_usages( + updated_function_node, self.classified_params + ) + + else: + # just remove the unused params if the used parameters are within the max param list + updated_function_node = self.function_updater.remove_unused_params( + self.function_node, self.used_params, default_value_params + ) + + class FunctionReplacer(cst.CSTTransformer): + def __init__( + self, original_function: cst.FunctionDef, updated_function: cst.FunctionDef + ): + self.original_function = original_function + self.updated_function = 
updated_function + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + """Replace the original function definition with the updated one.""" + if original_node.deep_equals(self.original_function): + return self.updated_function # replace with the modified function + return updated_node # leave other functions unchanged + + tree = tree.visit(FunctionReplacer(self.function_node, updated_function_node)) + + # Write the modified source + modified_source = tree.code + + with output_file.open("w") as temp_file: + temp_file.write(modified_source) + + if overwrite: + with target_file.open("w") as f: + f.write(modified_source) + + self.traverse_and_process(source_dir) + + def _generate_unique_param_class_names(self, target_line: int) -> tuple[str, str]: + """ + Generate unique class names for data params and config params based on function name and line number. + :return: A tuple containing (DataParams class name, ConfigParams class name). 
+ """ + unique_suffix = f"{self.function_node.name.value}_{target_line}" + data_class_name = f"DataParams_{unique_suffix}" + config_class_name = f"ConfigParams_{unique_suffix}" + return data_class_name, config_class_name + + def _process_file(self, file: Path): + if file.samefile(self.target_file): + return False + + tree = cst.parse_module(file.read_text()) + + visitor = FunctionCallVisitor( + self.function_node.name.value, self.enclosing_class_name, self.is_constructor + ) + tree.visit(visitor) + + if not visitor.found: + return False + + # insert class definitions before modifying function calls + tree = tree.visit(ClassInserter(self.classified_param_nodes)) + + # update function calls/class instantiations + tree = self.function_updater.update_function_calls( + tree, + self.function_node, + self.used_params, + self.classified_params, + self.classified_param_names, + self.enclosing_class_name, + ) + + modified_source = tree.code + with file.open("w") as f: + f.write(modified_source) + + return True diff --git a/src/ecooptimizer/refactorers/concrete/member_ignoring_method.py b/src/ecooptimizer/refactorers/concrete/member_ignoring_method.py new file mode 100644 index 00000000..25c02456 --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/member_ignoring_method.py @@ -0,0 +1,240 @@ +import astroid +from astroid import nodes, util +import libcst as cst +from libcst.metadata import PositionProvider, MetadataWrapper + +from pathlib import Path + +from ...config import CONFIG + +from ..multi_file_refactorer import MultiFileRefactorer +from ...data_types.smell import MIMSmell + + +class CallTransformer(cst.CSTTransformer): + METADATA_DEPENDENCIES = (PositionProvider,) + + def __init__(self, class_name: str): + self.method_calls: list[tuple[str, int, str, str]] = None # type: ignore + self.class_name = class_name # Class name to replace instance calls + self.transformed = False + + def set_calls(self, valid_calls: list[tuple[str, int, str, str]]): + self.method_calls 
= valid_calls + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: + """Transform instance calls to static calls if they match.""" + if isinstance(original_node.func, cst.Attribute): + caller = original_node.func.value + method = original_node.func.attr.value + position = self.get_metadata(PositionProvider, original_node, None) + + if not position: + raise TypeError("What do you mean you can't find the position?") + + # Check if this call matches one from astroid (by caller, method name, and line number) + for call_caller, line, call_method, cls in self.method_calls: + CONFIG["refactorLogger"].debug( + f"cst caller: {call_caller} at line {position.start.line}" + ) + if ( + method == call_method + and position.start.line == line + and caller.deep_equals(cst.parse_expression(call_caller)) + ): + CONFIG["refactorLogger"].debug("transforming") + # Transform `obj.method(args)` -> `ClassName.method(args)` + new_func = cst.Attribute( + value=cst.Name(cls), # Replace `obj` with class name + attr=original_node.func.attr, + ) + self.transformed = True + return updated_node.with_changes(func=new_func) + + return updated_node # Return unchanged if no match + + +def find_valid_method_calls( + tree: nodes.Module, mim_method: str, valid_classes: set[str] +) -> list[tuple[str, int, str, str]]: + """ + Finds method calls where the instance is of a valid class. + + Returns: + A list of (caller_name, line_number, method_name). 
+ """ + valid_calls = [] + + CONFIG["refactorLogger"].info("Finding valid method calls") + + for node in tree.body: + for descendant in node.nodes_of_class(nodes.Call): + if isinstance(descendant.func, nodes.Attribute): + CONFIG["refactorLogger"].debug(f"caller: {descendant.func.expr.as_string()}") + caller = descendant.func.expr # The object calling the method + method_name = descendant.func.attrname + + if method_name != mim_method: + continue + + inferred_types: list[str] = [] + inferrences = caller.infer() + + for inferred in inferrences: + CONFIG["refactorLogger"].debug(f"inferred: {inferred.repr_name()}") + if isinstance(inferred, util.UninferableBase): + hint = check_for_annotations(caller, descendant.scope()) + inits = check_for_initializations(caller, descendant.scope()) + if hint: + inferred_types.append(hint.as_string()) + elif inits: + inferred_types.extend(inits) + else: + continue + else: + inferred_types.append(inferred.repr_name()) + + CONFIG["refactorLogger"].debug(f"Inferred types: {inferred_types}") + + # Check if any inferred type matches a valid class + for cls in inferred_types: + if cls in valid_classes: + CONFIG["refactorLogger"].debug( + f"Foud valid call: {caller.as_string()} at line {descendant.lineno}" + ) + valid_calls.append( + (caller.as_string(), descendant.lineno, method_name, cls) + ) + + return valid_calls + + +def check_for_initializations(caller: nodes.NodeNG, scope: nodes.NodeNG): + inits: list[str] = [] + + for assign in scope.nodes_of_class(nodes.Assign): + if assign.targets[0].as_string() == caller.as_string() and isinstance( + assign.value, nodes.Call + ): + if isinstance(assign.value.func, nodes.Name): + inits.append(assign.value.func.name) + + return inits + + +def check_for_annotations(caller: nodes.NodeNG, scope: nodes.NodeNG): + if not isinstance(scope, nodes.FunctionDef): + return None + + hint = None + CONFIG["refactorLogger"].debug(f"annotations: {scope.args}") + + args = scope.args.args + anns = 
scope.args.annotations + if args and anns: + for arg, ann in zip(args, anns): + if arg.name == caller.as_string() and ann: + hint = ann + break + + return hint + + +class MakeStaticRefactorer(MultiFileRefactorer[MIMSmell], cst.CSTTransformer): + METADATA_DEPENDENCIES = (PositionProvider,) + + def __init__(self): + super().__init__() + self.target_line = None + self.mim_method_class = "" + self.mim_method = "" + self.valid_classes: set[str] = set() + self.transformer: CallTransformer = None # type: ignore + + def refactor( + self, + target_file: Path, + source_dir: Path, + smell: MIMSmell, + output_file: Path, + overwrite: bool = True, + ): + self.target_line = smell.occurences[0].line + self.target_file = target_file + + if not smell.obj: + raise TypeError("No method object found") + + self.mim_method_class, self.mim_method = smell.obj.split(".") + self.valid_classes.add(self.mim_method_class) + + source_code = target_file.read_text() + tree = MetadataWrapper(cst.parse_module(source_code)) + + # Find all subclasses of the target class + self._find_subclasses(source_dir) + + modified_tree = tree.visit(self) + target_file.write_text(modified_tree.code) + + self.transformer = CallTransformer(self.mim_method_class) + + self.traverse_and_process(source_dir) + if not overwrite: + output_file.write_text(target_file.read_text()) + + def _find_subclasses(self, directory: Path): + """Find all subclasses of the target class within the file.""" + + def get_subclasses(tree: nodes.Module): + subclasses: set[str] = set() + for klass in tree.nodes_of_class(nodes.ClassDef): + if any(base == self.mim_method_class for base in klass.basenames): + if not any(method.name == self.mim_method for method in klass.mymethods()): + subclasses.add(klass.name) + return subclasses + + CONFIG["refactorLogger"].debug("find all subclasses") + self.traverse(directory) + for file in self.py_files: + tree = astroid.parse(file.read_text()) + self.valid_classes = 
self.valid_classes.union(get_subclasses(tree)) + CONFIG["refactorLogger"].debug(f"valid classes: {self.valid_classes}") + + def _process_file(self, file: Path): + processed = False + + source_code = file.read_text("utf-8") + + astroid_tree = astroid.parse(source_code) + valid_calls = find_valid_method_calls(astroid_tree, self.mim_method, self.valid_classes) + self.transformer.set_calls(valid_calls) + + tree = MetadataWrapper(cst.parse_module(source_code)) + modified_tree = tree.visit(self.transformer) + + if self.transformer.transformed: + file.write_text(modified_tree.code) + if not file.samefile(self.target_file): + processed = True + self.transformer.transformed = False + + return processed + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + func_name = original_node.name.value + if func_name and updated_node.deep_equals(original_node): + position = self.get_metadata(PositionProvider, original_node).start # type: ignore + if position.line == self.target_line and func_name == self.mim_method: + CONFIG["refactorLogger"].debug("Modifying MIM method") + decorators = [ + *list(original_node.decorators), + cst.Decorator(cst.Name("staticmethod")), + ] + params = original_node.params + if params.params and params.params[0].name.value == "self": + params = params.with_changes(params=params.params[1:]) + return updated_node.with_changes(decorators=decorators, params=params) + return updated_node diff --git a/src/ecooptimizer/refactorers/concrete/repeated_calls.py b/src/ecooptimizer/refactorers/concrete/repeated_calls.py new file mode 100644 index 00000000..d45db02d --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/repeated_calls.py @@ -0,0 +1,169 @@ +import ast +import re +from pathlib import Path + +from ...data_types.smell import CRCSmell +from ..base_refactorer import BaseRefactorer + + +def extract_function_name(call_string: str): + """Extracts a specific function/method name from a call 
string.""" + match = re.match(r"(\w+)\.(\w+)\s*\(", call_string) # Match `obj.method()` + if match: + return f"{match.group(1)}_{match.group(2)}" # Format: cache_obj_method + match = re.match(r"(\w+)\s*\(", call_string) # Match `function()` + if match: + return f"{match.group(1)}" # Format: cache_function + return call_string # Fallback (shouldn't happen in valid calls) + + +class CacheRepeatedCallsRefactorer(BaseRefactorer[CRCSmell]): + def __init__(self): + """ + Initializes the CacheRepeatedCallsRefactorer. + """ + super().__init__() + self.target_line = None + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: CRCSmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactor the repeated function call smell and save to a new file. + """ + self.target_file = target_file + self.smell = smell + self.call_string = self.smell.additionalInfo.callString.strip() + + # Correctly generate cached variable name + self.cached_var_name = "cached_" + extract_function_name(self.call_string) + + with self.target_file.open("r") as file: + lines = file.readlines() + + # Parse the AST + tree = ast.parse("".join(lines)) + + # Find the valid parent node + parent_node = self._find_valid_parent(tree) + if not parent_node: + return + + # Determine the insertion point for the cached variable + insert_line = self._find_insert_line(parent_node) + indent = self._get_indentation(lines, insert_line) + cached_assignment = f"{indent}{self.cached_var_name} = {self.call_string}\n" + + # Insert the cached variable into the source lines + lines.insert(insert_line - 1, cached_assignment) + line_shift = 1 # Track the shift in line numbers caused by the insertion + + # Replace calls with the cached variable in the affected lines + for occurrence in self.smell.occurences: + adjusted_line_index = occurrence.line - 1 + line_shift + original_line = lines[adjusted_line_index] + updated_line = self._replace_call_in_line( + original_line, self.call_string, 
self.cached_var_name + ) + if updated_line != original_line: + lines[adjusted_line_index] = updated_line + + # Multi-file implementation + if overwrite: + with target_file.open("w") as f: + f.writelines(lines) + else: + with output_file.open("w") as f: + f.writelines(lines) + + def _get_indentation(self, lines: list[str], line_number: int): + """Determine the indentation level of a given line.""" + line = lines[line_number - 1] + return line[: len(line) - len(line.lstrip())] + + def _replace_call_in_line(self, line: str, call_string: str, cached_var_name: str): + """ + Replace the repeated call in a line with the cached variable. + """ + return line.replace(call_string, cached_var_name) + + def _find_valid_parent(self, tree: ast.Module): + """ + Find the valid parent node that contains all occurrences of the repeated call. + """ + candidate_parent = None + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if all(self._line_in_node_body(node, occ.line) for occ in self.smell.occurences): + candidate_parent = node + return candidate_parent + + def _find_insert_line(self, parent_node: ast.FunctionDef | ast.ClassDef | ast.Module): + """ + Find the line to insert the cached variable assignment. + + - If it's a function, insert at the beginning but **after a docstring** if present. + - If it's a method call (`obj.method()`), insert after `obj` is defined. + - If it's a lambda assignment (`compute_demo = lambda ...`), insert after it. 
+ """ + if isinstance(parent_node, ast.Module): + return 1 # Top of the module + + # Extract variable or function name from call string + var_match = re.match(r"(\w+)\.", self.call_string) # Matches `obj.method()` + if var_match: + obj_name = var_match.group(1) # Extract `obj` + + # Find the first assignment of `obj` + for node in parent_node.body: + if isinstance(node, ast.Assign): + if any( + isinstance(target, ast.Name) and target.id == obj_name + for target in node.targets + ): + return node.lineno + 1 # Insert after the assignment of `obj` + + # Find the first lambda assignment + for node in parent_node.body: + if isinstance(node, ast.Assign) and isinstance(node.value, ast.Lambda): + lambda_var_name = node.targets[0].id # Extract variable name + if lambda_var_name in self.call_string: + return node.lineno + 1 # Insert after the lambda function + + # Check if the first statement is a docstring + if ( + isinstance(parent_node.body[0], ast.Expr) + and isinstance(parent_node.body[0].value, ast.Constant) + and isinstance(parent_node.body[0].value.value, str) # Ensures it's a string docstring + ): + docstring_start = parent_node.body[0].lineno + docstring_end = docstring_start + + # Find the last line of the docstring by counting the lines it spans + docstring_content = parent_node.body[0].value.value + docstring_lines = docstring_content.count("\n") + if docstring_lines > 0: + docstring_end += docstring_lines + + return docstring_end + 1 # Insert after the last line of the docstring + + return parent_node.body[0].lineno # Default: insert at function start + + def _line_in_node_body(self, node: ast.FunctionDef | ast.ClassDef | ast.Module, line: int): + """ + Check if a line is within the body of a given AST node. 
+ """ + if not hasattr(node, "body"): + return False + + for child in node.body: + if hasattr(child, "lineno") and child.lineno <= line <= getattr( + child, "end_lineno", child.lineno + ): + return True + return False diff --git a/src/ecooptimizer/refactorers/concrete/str_concat_in_loop.py b/src/ecooptimizer/refactorers/concrete/str_concat_in_loop.py new file mode 100644 index 00000000..e4575844 --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/str_concat_in_loop.py @@ -0,0 +1,303 @@ +import re + +from pathlib import Path +import astroid +from astroid import nodes + +from ..base_refactorer import BaseRefactorer +from ...data_types.smell import SCLSmell + + +class UseListAccumulationRefactorer(BaseRefactorer[SCLSmell]): + """ + Refactorer that targets string concatenations inside loops + """ + + def __init__(self): + super().__init__() + self.target_lines: list[int] = [] + self.assign_var = "" + self.target_node: nodes.NodeNG = None + self.last_assign_node: nodes.Assign | nodes.AugAssign = None # type: ignore + self.concat_nodes: list[nodes.Assign | nodes.AugAssign] = [] + self.reassignments: list[nodes.Assign] = [] + self.outer_loop_line: int = 0 + self.outer_loop: nodes.For | nodes.While = None # type: ignore + + def reset(self): + self.__init__() + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: SCLSmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactor string concatenations in loops to use list accumulation and join + + :param target_file: absolute path to source code + :param smell: pylint code for smell + :param initial_emission: inital carbon emission prior to refactoring + """ + self.target_lines = [occ.line for occ in smell.occurences] + + if not smell.additionalInfo: + raise RuntimeError("Missing additional info for 'string-concat-loop' smell") + + self.assign_var = smell.additionalInfo.concatTarget + self.outer_loop_line = smell.additionalInfo.innerLoopLine + + # Parse the code into an 
AST + source_code = target_file.read_text() + tree = astroid.parse(source_code) + for node in tree.get_children(): + self.visit(node) + + if not self.outer_loop or len(self.concat_nodes) != len(self.target_lines): + raise Exception("Missing inner loop or concat nodes.") + + self.find_reassignments() + self.find_scope() + + temp_concat_nodes = [("concat", node) for node in self.concat_nodes] + temp_reassignments = [("reassign", node) for node in self.reassignments] + + combined_nodes = temp_concat_nodes + temp_reassignments + + combined_nodes = sorted( + combined_nodes, + key=lambda x: x[1].lineno, # type: ignore + reverse=True, + ) + + modified_code = self.add_node_to_body(source_code, combined_nodes) + + if overwrite: + target_file.write_text(modified_code) + else: + output_file.write_text(modified_code) + + def visit(self, node: nodes.NodeNG): + if isinstance(node, nodes.Assign) and node.lineno in self.target_lines: + if not self.target_node: + self.target_node = node.targets[0] + self.concat_nodes.append(node) + elif isinstance(node, nodes.AugAssign) and node.lineno in self.target_lines: + if not self.target_node: + self.target_node = node.target + self.concat_nodes.append(node) + elif isinstance(node, (nodes.For, nodes.While)) and node.lineno == self.outer_loop_line: + self.outer_loop = node + for child in node.get_children(): + self.visit(child) + else: + for child in node.get_children(): + self.visit(child) + + def find_reassignments(self): + for node in self.outer_loop.nodes_of_class(nodes.Assign): + for target in node.targets: + if target.as_string() == self.assign_var and node.lineno not in self.target_lines: + self.reassignments.append(node) + + def find_last_assignment(self, scope_node: nodes.NodeNG): + """Find the last assignment of the target variable within a given scope node.""" + last_assignment_node = None + + # Traverse the scope node and find assignments within the valid range + for node in scope_node.nodes_of_class((nodes.AugAssign, 
nodes.Assign)): + if isinstance(node, nodes.Assign): + for target in node.targets: + if ( + target.as_string() == self.assign_var + and node.lineno < self.outer_loop.lineno # type: ignore + ): + if last_assignment_node is None: + last_assignment_node = node + elif node.lineno > last_assignment_node.lineno: # type: ignore + last_assignment_node = node + else: + if ( + node.target.as_string() == self.assign_var + and node.lineno < self.outer_loop.lineno # type: ignore + ): + if last_assignment_node is None: + last_assignment_node = node + elif node.lineno > last_assignment_node.lineno: # type: ignore + last_assignment_node = node + + self.last_assign_node = last_assignment_node # type: ignore + + def find_scope(self): + """Locate the second innermost loop if nested, else find first non-loop function/method/module ancestor.""" + + for node in self.outer_loop.node_ancestors(): + if isinstance(node, (nodes.For, nodes.While)): + self.find_last_assignment(node) + if not self.last_assign_node: + self.outer_loop = node + else: + self.scope_node = node + break + elif isinstance(node, (nodes.Module, nodes.FunctionDef, nodes.AsyncFunctionDef)): + self.find_last_assignment(node) + self.scope_node = node + break + + def last_assign_is_referenced(self, search_area: str): + return ( + search_area.find(self.assign_var) != -1 + or isinstance(self.last_assign_node, nodes.AugAssign) + or self.assign_var in self.last_assign_node.value.as_string() + ) + + def generate_temp_list_name(self): + node = self.target_node + + def _get_node_representation(node: nodes.NodeNG): + """Helper function to get a string representation of a node.""" + if isinstance(node, astroid.Const): + return str(node.value) + if isinstance(node, astroid.Name): + return node.name + if isinstance(node, astroid.Attribute): + return node.attrname + return "unknown" + + if isinstance(node, astroid.Subscript): + # Extracting slice and value for a Subscript node + slice_repr = _get_node_representation(node.slice) + 
value_repr = _get_node_representation(node.value) + custom_component = f"{value_repr}_at_{slice_repr}" + elif isinstance(node, astroid.AssignAttr): + # Extracting attribute name for an AssignAttr node + attribute_name = node.attrname + custom_component = attribute_name + else: + raise TypeError("Node must be either Subscript or AssignAttr.") + + return f"temp_{custom_component}" + + def add_node_to_body(self, code_file: str, nodes_to_change: list[tuple]): # type: ignore + """ + Add a new AST node + """ + + code_file_lines = code_file.splitlines() + + list_name = self.assign_var + + if not isinstance(self.target_node, nodes.AssignName): + list_name = self.generate_temp_list_name() + + # ------------- ADD JOIN STATEMENT TO SOURCE ---------------- + + join_line = f"{self.assign_var} = ''.join({list_name})" + indent_lno: int = self.outer_loop.lineno - 1 # type: ignore + join_lno: int = self.outer_loop.end_lineno # type: ignore + + source_line = code_file_lines[indent_lno] + outer_scope_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.insert(join_lno, outer_scope_whitespace + join_line) + + def get_new_concat_line(concat_node: nodes.AugAssign | nodes.Assign): + concat_line = "" + if isinstance(concat_node, nodes.AugAssign): + concat_line = f"{list_name}.append({concat_node.value.as_string()})" + else: + parts = re.split( + rf"\s*[+]*\s*\b{re.escape(self.assign_var)}\b\s*[+]*\s*", + concat_node.value.as_string(), + ) + + if len(parts[0]) == 0: + concat_line = f"{list_name}.append({parts[1]})" + elif len(parts[1]) == 0: + concat_line = f"{list_name}.insert(0, {parts[0]})" + else: + concat_line = [ + f"{list_name}.insert(0, {parts[0]})", + f"{list_name}.append({parts[1]})", + ] + return concat_line + + def get_new_reassign_line(reassign_node: nodes.Assign): + if reassign_node.value.as_string() in ["''", '""', "str()"]: + return f"{list_name}.clear()" + else: + return f"{list_name} = [{reassign_node.value.as_string()}]" + + # 
------------- REFACTOR CONCATS and REASSIGNS ---------------------------- + + for node in nodes_to_change: + if node[0] == "concat": + new_concat = get_new_concat_line(node[1]) + concat_lno = node[1].lineno - 1 + + if isinstance(new_concat, list): + source_line = code_file_lines[concat_lno] + concat_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.pop(concat_lno) + code_file_lines.insert(concat_lno, concat_whitespace + new_concat[1]) + code_file_lines.insert(concat_lno, concat_whitespace + new_concat[0]) + else: + source_line = code_file_lines[concat_lno] + concat_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.pop(concat_lno) + code_file_lines.insert(concat_lno, concat_whitespace + new_concat) + else: + new_reassign = get_new_reassign_line(node[1]) + reassign_lno = node[1].lineno - 1 + + source_line = code_file_lines[reassign_lno] + reassign_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.pop(reassign_lno) + code_file_lines.insert(reassign_lno, reassign_whitespace + new_reassign) + + # ------------- INITIALIZE TARGET VAR AS A LIST ------------- + if ( + not isinstance(self.target_node, nodes.AssignName) + or not self.last_assign_node + or self.last_assign_is_referenced( + "".join(code_file_lines[self.last_assign_node.lineno : self.outer_loop.lineno - 1]) # type: ignore + ) + ): + list_lno: int = self.outer_loop.lineno - 1 # type: ignore + + source_line = code_file_lines[list_lno] + outer_scope_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + list_line = f"{list_name} = [{self.assign_var}]" + + code_file_lines.insert(list_lno, outer_scope_whitespace + list_line) + + elif self.last_assign_node.value.as_string() in ["''", "str()"]: + list_lno: int = self.last_assign_node.lineno - 1 # type: ignore + + source_line = code_file_lines[list_lno] + outer_scope_whitespace = source_line[: len(source_line) - 
# pyright: reportOptionalMemberAccess=false
from abc import abstractmethod
import fnmatch
from pathlib import Path
from typing import TypeVar

from ..config import CONFIG

from .base_refactorer import BaseRefactorer

from ..data_types.smell import Smell


T = TypeVar("T", bound=Smell)

DEFAULT_IGNORED_PATTERNS = {
    "__pycache__",
    "build",
    ".venv",
    "*.egg-info",
    ".git",
    "node_modules",
    ".*",
}

DEFAULT_IGNORE_PATH = Path(__file__).parent / "patterns_to_ignore"


class MultiFileRefactorer(BaseRefactorer[T]):
    """Base class for refactorings whose effects may span several files.

    Subclasses implement ``_process_file``; ``traverse_and_process`` walks a
    directory tree (honouring gitignore-like patterns) and applies it to every
    Python file found.
    """

    def __init__(self):
        super().__init__()
        self.target_file: Path = None  # type: ignore  # set by subclasses before traversal
        self.ignore_patterns = self._load_ignore_patterns()
        self.py_files: list[Path] = []

    def _load_ignore_patterns(self, ignore_dir: Path = DEFAULT_IGNORE_PATH) -> set[str]:
        """Load ignore patterns from a file, similar to .gitignore.

        Returns the built-in defaults merged with every non-comment line of
        each file in *ignore_dir*.
        """
        if not ignore_dir.is_dir():
            # Return a copy so callers can never mutate the shared default.
            return set(DEFAULT_IGNORED_PATTERNS)

        # Fix: copy the defaults before updating. The original aliased
        # DEFAULT_IGNORED_PATTERNS and mutated the module-level constant in
        # place, polluting it for every later instance.
        patterns = set(DEFAULT_IGNORED_PATTERNS)
        for file in ignore_dir.iterdir():
            with file.open() as f:
                patterns.update(
                    [line.strip() for line in f if line.strip() and not line.startswith("#")]
                )

        return patterns

    def is_ignored(self, item: Path) -> bool:
        """Check if a file or directory matches any ignore pattern.

        NOTE(review): matching is against the basename only, so directory
        patterns with a trailing slash (e.g. ".vscode/" in the shipped ignore
        files) never match -- confirm whether those entries should be
        normalized when loaded.
        """
        return any(fnmatch.fnmatch(item.name, pattern) for pattern in self.ignore_patterns)

    def traverse(self, directory: Path):
        """Recursively collect every non-ignored ``.py`` file under *directory*."""
        for item in directory.iterdir():
            if item.is_dir():
                CONFIG["refactorLogger"].debug(f"Scanning directory: {item!s}, name: {item.name}")
                if self.is_ignored(item):
                    CONFIG["refactorLogger"].debug(f"Ignored directory: {item!s}")
                    continue

                CONFIG["refactorLogger"].debug(f"Entering directory: {item!s}")
                # Fix: recurse with traverse(), not traverse_and_process().
                # The original re-entered traverse_and_process(), which skips
                # traversal entirely once py_files is non-empty -- sibling
                # directories after the first .py hit were never scanned, and
                # the partial file list was re-processed once per subdirectory.
                self.traverse(item)
            elif item.is_file() and item.suffix == ".py":
                self.py_files.append(item)

    def traverse_and_process(self, directory: Path):
        """Collect the ``.py`` files (once) and run ``_process_file`` on each.

        Files reported as modified are recorded in ``self.modified_files``,
        excluding the primary target file and duplicates.
        """
        if not self.py_files:
            self.traverse(directory)
        for file in self.py_files:
            CONFIG["refactorLogger"].debug(f"Checking file: {file!s}")
            if self._process_file(file):
                # NOTE(review): membership is checked against `file` but the
                # resolved path is appended -- duplicates could slip through
                # when symlinks are involved; confirm intent.
                if file not in self.modified_files and not file.samefile(self.target_file):
                    self.modified_files.append(file.resolve())
            CONFIG["refactorLogger"].debug("finished processing file")

    @abstractmethod
    def _process_file(self, file: Path) -> bool:
        """Abstract method to be implemented by subclasses to handle file processing."""
        pass
a/src/ecooptimizer/refactorers/patterns_to_ignore/.pythonignore b/src/ecooptimizer/refactorers/patterns_to_ignore/.pythonignore new file mode 100644 index 00000000..1800114d --- /dev/null +++ b/src/ecooptimizer/refactorers/patterns_to_ignore/.pythonignore @@ -0,0 +1,174 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
# pyright: reportOptionalMemberAccess=false
from pathlib import Path

from ..config import CONFIG

from ..data_types.smell import Smell
from ..utils.smells_registry import get_refactorer


class RefactorerController:
    def __init__(self):
        """Manages the execution of refactorers for detected code smells."""
        # Per-smell-id occurrence counter, used to build unique output names.
        self.smell_counters = {}

    def run_refactorer(
        self, target_file: Path, source_dir: Path, smell: Smell, overwrite: bool = True
    ):
        """Executes the appropriate refactorer for the given smell.

        Args:
            target_file (Path): The file to be refactored.
            source_dir (Path): The source directory containing the file.
            smell (Smell): The detected smell to be refactored.
            overwrite (bool, optional): Whether to overwrite existing files. Defaults to True.

        Returns:
            list[Path]: A list of modified files resulting from the refactoring process.

        Raises:
            NotImplementedError: If no refactorer exists for the given smell.
        """
        smell_id = smell.messageId
        smell_symbol = smell.symbol
        refactorer_class = get_refactorer(smell_symbol)

        # Guard clause: no registered refactorer means we log and bail out.
        if not refactorer_class:
            CONFIG["refactorLogger"].error(f"❌ No refactorer found for smell: {smell_symbol}")
            raise NotImplementedError(f"No refactorer implemented for smell: {smell_symbol}")

        nth_occurrence = self.smell_counters.get(smell_id, 0) + 1
        self.smell_counters[smell_id] = nth_occurrence

        output_file_name = f"{target_file.stem}_path_{smell_id}_{nth_occurrence}.py"
        output_file_path = Path(__file__).parent / "../../../outputs" / output_file_name

        CONFIG["refactorLogger"].info(
            f"πŸ”„ Running refactoring for {smell_symbol} using {refactorer_class.__name__}"
        )
        refactorer = refactorer_class()
        refactorer.refactor(target_file, source_dir, smell, output_file_path, overwrite)
        return refactorer.modified_files
"detect.log", + "refactor": self.logs_dir / "refactor.log", + } + self._setup_loggers() + + def _initialize_output_structure(self): + """Ensures required directories exist and clears old logs.""" + if not self.production: + DEV_OUTPUT.mkdir(exist_ok=True) + self.logs_dir.mkdir(exist_ok=True) + + def _clear_logs(self): + """Removes existing log files while preserving the log directory.""" + if self.logs_dir.exists(): + for log_file in self.logs_dir.iterdir(): + if log_file.is_file(): + log_file.unlink() + logging.info("πŸ—‘οΈ Cleared existing log files.") + + def _setup_loggers(self): + """Configures loggers for different EcoOptimizer processes.""" + logging.root.handlers.clear() + + logging.basicConfig( + filename=str(self.log_files["main"]), + filemode="a", + level=logging.INFO, + format="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + ) + + self.loggers = { + "detect": self._create_logger( + "detect", self.log_files["detect"], self.log_files["main"] + ), + "refactor": self._create_logger( + "refactor", self.log_files["refactor"], self.log_files["main"] + ), + } + + logging.info("πŸ“ Loggers initialized successfully.") + + def _create_logger(self, name: str, log_file: Path, main_log_file: Path): + """ + Creates a logger that logs to both its own file and the main log file. + + Args: + name (str): Name of the logger. + log_file (Path): Path to the specific log file. + main_log_file (Path): Path to the main log file. + + Returns: + logging.Logger: Configured logger instance. 
+ """ + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.propagate = False + + file_handler = logging.FileHandler(str(log_file), mode="a", encoding="utf-8") + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s", "%Y-%m-%d %H:%M:%S" + ) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + main_handler = logging.FileHandler(str(main_log_file), mode="a", encoding="utf-8") + main_handler.setFormatter(formatter) + logger.addHandler(main_handler) + + logging.info(f"πŸ“ Logger '{name}' initialized and writing to {log_file}.") + return logger + + +def save_file(file_name: str, data: str, mode: str, message: str = ""): + """Saves data to a file in the output directory.""" + file_path = DEV_OUTPUT / file_name + with file_path.open(mode) as file: + file.write(data) + log_message = message if message else f"πŸ“ {file_name} saved to {file_path!s}" + logging.info(log_message) + + +def save_json_files(file_name: str, data: dict[Any, Any] | list[Any]): + """Saves data to a JSON file in the output directory.""" + file_path = DEV_OUTPUT / file_name + file_path.write_text(json.dumps(data, cls=EnumEncoder, sort_keys=True, indent=4)) + logging.info(f"πŸ“ {file_name} saved to {file_path!s} as JSON file") + + +def copy_file_to_output(source_file_path: Path, new_file_name: str): + """Copies a file to the output directory with a new name.""" + destination_path = DEV_OUTPUT / new_file_name + shutil.copy(source_file_path, destination_path) + logging.info(f"πŸ“ {new_file_name} copied to {destination_path!s}") + return destination_path diff --git a/src/ecooptimizer/utils/smell_enums.py b/src/ecooptimizer/utils/smell_enums.py new file mode 100644 index 00000000..3661002e --- /dev/null +++ b/src/ecooptimizer/utils/smell_enums.py @@ -0,0 +1,29 @@ +from enum import Enum + + +class ExtendedEnum(Enum): + @classmethod + def list(cls) -> list[str]: + return [c.value for c in cls] + + def __eq__(self, value: object) 
from enum import Enum


class ExtendedEnum(Enum):
    """Enum base whose members compare equal to the string form of their value."""

    @classmethod
    def list(cls) -> list[str]:
        """Return every member's value, in definition order."""
        return [c.value for c in cls]

    def __eq__(self, value: object) -> bool:
        # Members compare equal to their (stringified) value, so e.g.
        # PylintSmell.USE_A_GENERATOR == "R1729" holds; member-to-member
        # comparison still works via reflected dispatch.
        return str(self.value) == value

    # Fix: defining __eq__ implicitly sets __hash__ to None, which silently
    # made every member unhashable (unusable in sets or as dict keys).
    # Restore Enum's name-based hash. Note hash(member) != hash(member.value),
    # so members and raw strings are not interchangeable as dict keys.
    __hash__ = Enum.__hash__


# Enum class for standard Pylint code smells
class PylintSmell(ExtendedEnum):
    LONG_PARAMETER_LIST = "R0913"  # Pylint code smell for functions with too many parameters
    NO_SELF_USE = "R6301"  # Pylint code smell for class methods that don't use any self calls
    USE_A_GENERATOR = (
        "R1729"  # Pylint code smell for unnecessary list comprehensions inside `any()` or `all()`
    )


# Enum class for custom code smells not detected by Pylint
class CustomSmell(ExtendedEnum):
    LONG_MESSAGE_CHAIN = "LMC001"  # Ast code smell for long message chains
    UNUSED_VAR_OR_ATTRIBUTE = "UVA001"  # Ast code smell for unused variable or attribute
    LONG_ELEMENT_CHAIN = "LEC001"  # Ast code smell for long element chains
    LONG_LAMBDA_EXPR = "LLE001"  # Ast code smell for long lambda expressions
    STR_CONCAT_IN_LOOP = "SCL001"  # Astroid code smell for string concatenation inside loops
    CACHE_REPEATED_CALLS = "CRC001"  # Ast code smell for repeated calls
from copy import deepcopy
from .smell_enums import CustomSmell, PylintSmell

from ..analyzers.ast_analyzers.detect_long_element_chain import detect_long_element_chain
from ..analyzers.ast_analyzers.detect_long_lambda_expression import detect_long_lambda_expression
from ..analyzers.ast_analyzers.detect_long_message_chain import detect_long_message_chain
from ..analyzers.astroid_analyzers.detect_string_concat_in_loop import detect_string_concat_in_loop
from ..analyzers.ast_analyzers.detect_repeated_calls import detect_repeated_calls

from ..refactorers.concrete.list_comp_any_all import UseAGeneratorRefactorer

from ..refactorers.concrete.long_lambda_function import LongLambdaFunctionRefactorer
from ..refactorers.concrete.long_element_chain import LongElementChainRefactorer
from ..refactorers.concrete.long_message_chain import LongMessageChainRefactorer
from ..refactorers.concrete.member_ignoring_method import MakeStaticRefactorer
from ..refactorers.concrete.long_parameter_list import LongParameterListRefactorer
from ..refactorers.concrete.str_concat_in_loop import UseListAccumulationRefactorer
from ..refactorers.concrete.repeated_calls import CacheRepeatedCallsRefactorer

from ..data_types.smell_record import SmellRecord

# Master registry mapping a smell symbol to its id, detection strategy and
# refactorer. Treated as immutable: accessors below hand out copies or
# individual entries, never the live dict.
_SMELL_REGISTRY: dict[str, SmellRecord] = {
    "use-a-generator": {
        "id": PylintSmell.USE_A_GENERATOR.value,
        "enabled": True,
        "analyzer_method": "pylint",
        "checker": None,
        "analyzer_options": {},
        "refactorer": UseAGeneratorRefactorer,
    },
    "too-many-arguments": {
        "id": PylintSmell.LONG_PARAMETER_LIST.value,
        "enabled": True,
        "analyzer_method": "pylint",
        "checker": None,
        "analyzer_options": {"max_args": {"flag": "--max-args", "value": 6}},
        "refactorer": LongParameterListRefactorer,
    },
    "no-self-use": {
        "id": PylintSmell.NO_SELF_USE.value,
        "enabled": True,
        "analyzer_method": "pylint",
        "checker": None,
        "analyzer_options": {
            "load-plugin": {"flag": "--load-plugins", "value": "pylint.extensions.no_self_use"}
        },
        "refactorer": MakeStaticRefactorer,
    },
    "long-lambda-expression": {
        "id": CustomSmell.LONG_LAMBDA_EXPR.value,
        "enabled": True,
        "analyzer_method": "ast",
        "checker": detect_long_lambda_expression,
        "analyzer_options": {"threshold_length": 100, "threshold_count": 5},
        "refactorer": LongLambdaFunctionRefactorer,
    },
    "long-message-chain": {
        "id": CustomSmell.LONG_MESSAGE_CHAIN.value,
        "enabled": True,
        "analyzer_method": "ast",
        "checker": detect_long_message_chain,
        "analyzer_options": {"threshold": 3},
        "refactorer": LongMessageChainRefactorer,
    },
    "long-element-chain": {
        "id": CustomSmell.LONG_ELEMENT_CHAIN.value,
        "enabled": True,
        "analyzer_method": "ast",
        "checker": detect_long_element_chain,
        "analyzer_options": {"threshold": 3},
        "refactorer": LongElementChainRefactorer,
    },
    "cached-repeated-calls": {
        "id": CustomSmell.CACHE_REPEATED_CALLS.value,
        "enabled": True,
        "analyzer_method": "ast",
        "checker": detect_repeated_calls,
        "analyzer_options": {"threshold": 2},
        "refactorer": CacheRepeatedCallsRefactorer,
    },
    "string-concat-loop": {
        "id": CustomSmell.STR_CONCAT_IN_LOOP.value,
        "enabled": True,
        "analyzer_method": "astroid",
        "checker": detect_string_concat_in_loop,
        "analyzer_options": {},
        "refactorer": UseListAccumulationRefactorer,
    },
}


def retrieve_smell_registry(enabled_smells: list[str] | str):
    """Returns a modified SMELL_REGISTRY based on user preferences (enables/disables smells)."""
    if enabled_smells == "ALL":
        return deepcopy(_SMELL_REGISTRY)
    # Fix: deep-copy the filtered branch too. The original returned live
    # references to registry records here while the "ALL" branch deep-copied,
    # so callers could accidentally mutate the shared registry.
    return {
        key: deepcopy(val) for (key, val) in _SMELL_REGISTRY.items() if key in enabled_smells
    }


def get_refactorer(symbol: str):
    """Return the refactorer class registered for *symbol*, or None if unknown.

    Fix: an unknown symbol previously raised KeyError; returning None lets
    RefactorerController raise its documented NotImplementedError instead.
    """
    record = _SMELL_REGISTRY.get(symbol)
    return record.get("refactorer", None) if record else None
- """ - - def __init__(self): - """ - Initializes the EnergyMeter class. - """ - # Optional: Any initialization for the energy measurement can go here - pass - - def measure_energy(self, func: Callable, *args, **kwargs): - """ - Measures the energy consumed by the specified function during its execution. - - Parameters: - - func (Callable): The function to measure. - - *args: Arguments to pass to the function. - - **kwargs: Keyword arguments to pass to the function. - - Returns: - - tuple: A tuple containing the return value of the function and the energy consumed (in Joules). - """ - start_energy = joules.getEnergy() # Start measuring energy - start_time = time.time() # Record start time - - result = func(*args, **kwargs) # Call the specified function - - end_time = time.time() # Record end time - end_energy = joules.getEnergy() # Stop measuring energy - - energy_consumed = end_energy - start_energy # Calculate energy consumed - - # Log the timing (optional) - print(f"Execution Time: {end_time - start_time:.6f} seconds") - print(f"Energy Consumed: {energy_consumed:.6f} Joules") - - return result, energy_consumed # Return the result of the function and the energy consumed - - def measure_block(self, code_block: str): - """ - Measures energy consumption for a block of code represented as a string. - - Parameters: - - code_block (str): A string containing the code to execute. - - Returns: - - float: The energy consumed (in Joules). 
- """ - local_vars = {} - exec(code_block, {}, local_vars) # Execute the code block - energy_consumed = joules.getEnergy() # Measure energy after execution - print(f"Energy Consumed for the block: {energy_consumed:.6f} Joules") - return energy_consumed diff --git a/src/refactorer/base_refactorer.py b/src/refactorer/base_refactorer.py deleted file mode 100644 index 698440fb..00000000 --- a/src/refactorer/base_refactorer.py +++ /dev/null @@ -1,24 +0,0 @@ -# src/refactorer/base_refactorer.py - -from abc import ABC, abstractmethod - -class BaseRefactorer(ABC): - """ - Abstract base class for refactorers. - Subclasses should implement the `refactor` method. - """ - - def __init__(self, code): - """ - Initialize the refactorer with the code to refactor. - - :param code: The code that needs refactoring - """ - self.code = code - - def refactor(self): - """ - Perform the refactoring process. - Must be implemented by subclasses. - """ - raise NotImplementedError("Subclasses should implement this method") diff --git a/src/refactorer/complex_list_comprehension_refactorer.py b/src/refactorer/complex_list_comprehension_refactorer.py deleted file mode 100644 index b4a96586..00000000 --- a/src/refactorer/complex_list_comprehension_refactorer.py +++ /dev/null @@ -1,115 +0,0 @@ -import ast -import astor - -class ComplexListComprehensionRefactorer: - """ - Refactorer for complex list comprehensions to improve readability. - """ - - def __init__(self, code: str): - """ - Initializes the refactorer. - - :param code: The source code to refactor. - """ - self.code = code - - def refactor(self): - """ - Refactor the code by transforming complex list comprehensions into for-loops. - - :return: The refactored code. 
- """ - # Parse the code to get the AST - tree = ast.parse(self.code) - - # Walk through the AST and refactor complex list comprehensions - for node in ast.walk(tree): - if isinstance(node, ast.ListComp): - # Check if the list comprehension is complex - if self.is_complex(node): - # Create a for-loop equivalent - for_loop = self.create_for_loop(node) - # Replace the list comprehension with the for-loop in the AST - self.replace_node(node, for_loop) - - # Convert the AST back to code - return self.ast_to_code(tree) - - def create_for_loop(self, list_comp: ast.ListComp) -> ast.For: - """ - Create a for-loop that represents the list comprehension. - - :param list_comp: The ListComp node to convert. - :return: An ast.For node representing the for-loop. - """ - # Create the variable to hold results - result_var = ast.Name(id='result', ctx=ast.Store()) - - # Create the for-loop - for_loop = ast.For( - target=ast.Name(id='item', ctx=ast.Store()), - iter=list_comp.generators[0].iter, - body=[ - ast.Expr(value=ast.Call( - func=ast.Name(id='append', ctx=ast.Load()), - args=[self.transform_value(list_comp.elt)], - keywords=[] - )) - ], - orelse=[] - ) - - # Create a list to hold results - result_list = ast.List(elts=[], ctx=ast.Store()) - return ast.With( - context_expr=ast.Name(id='result', ctx=ast.Load()), - body=[for_loop], - lineno=list_comp.lineno, - col_offset=list_comp.col_offset - ) - - def transform_value(self, value_node: ast.AST) -> ast.AST: - """ - Transform the value in the list comprehension into a form usable in a for-loop. - - :param value_node: The value node to transform. - :return: The transformed value node. - """ - return value_node - - def replace_node(self, old_node: ast.AST, new_node: ast.AST): - """ - Replace an old node in the AST with a new node. - - :param old_node: The node to replace. - :param new_node: The node to insert in its place. 
- """ - parent = self.find_parent(old_node) - if parent: - for index, child in enumerate(ast.iter_child_nodes(parent)): - if child is old_node: - parent.body[index] = new_node - break - - def find_parent(self, node: ast.AST) -> ast.AST: - """ - Find the parent node of a given AST node. - - :param node: The node to find the parent for. - :return: The parent node, or None if not found. - """ - for parent in ast.walk(node): - for child in ast.iter_child_nodes(parent): - if child is node: - return parent - return None - - def ast_to_code(self, tree: ast.AST) -> str: - """ - Convert AST back to source code. - - :param tree: The AST to convert. - :return: The source code as a string. - """ - return astor.to_source(tree) diff --git a/src/refactorer/large_class_refactorer.py b/src/refactorer/large_class_refactorer.py deleted file mode 100644 index aff1f32d..00000000 --- a/src/refactorer/large_class_refactorer.py +++ /dev/null @@ -1,83 +0,0 @@ -import ast - -class LargeClassRefactorer: - """ - Refactorer for large classes that have too many methods. - """ - - def __init__(self, code: str, method_threshold: int = 5): - """ - Initializes the refactorer. - - :param code: The source code of the class to refactor. - :param method_threshold: The number of methods above which a class is considered large. - """ - self.code = code - self.method_threshold = method_threshold - - def refactor(self): - """ - Refactor the class by splitting it into smaller classes if it exceeds the method threshold. - - :return: The refactored code. 
- """ - # Parse the code to get the class definition - tree = ast.parse(self.code) - class_definitions = [node for node in tree.body if isinstance(node, ast.ClassDef)] - - refactored_code = [] - - for class_def in class_definitions: - methods = [n for n in class_def.body if isinstance(n, ast.FunctionDef)] - if len(methods) > self.method_threshold: - # If the class is large, split it - new_classes = self.split_class(class_def, methods) - refactored_code.extend(new_classes) - else: - # Keep the class as is - refactored_code.append(class_def) - - # Convert the AST back to code - return self.ast_to_code(refactored_code) - - def split_class(self, class_def, methods): - """ - Split the large class into smaller classes based on methods. - - :param class_def: The class definition node. - :param methods: The list of methods in the class. - :return: A list of new class definitions. - """ - # For demonstration, we'll simply create two classes based on the method count - half_index = len(methods) // 2 - new_class1 = self.create_new_class(class_def.name + "Part1", methods[:half_index]) - new_class2 = self.create_new_class(class_def.name + "Part2", methods[half_index:]) - - return [new_class1, new_class2] - - def create_new_class(self, new_class_name, methods): - """ - Create a new class definition with the specified methods. - - :param new_class_name: Name of the new class. - :param methods: List of methods to include in the new class. - :return: A new class definition node. - """ - # Create the class definition with methods - class_def = ast.ClassDef( - name=new_class_name, - bases=[], - body=methods, - decorator_list=[] - ) - return class_def - - def ast_to_code(self, nodes): - """ - Convert AST nodes back to source code. - - :param nodes: The AST nodes to convert. - :return: The source code as a string. 
- """ - import astor - return astor.to_source(nodes) diff --git a/src/refactorer/long_method_refactorer.py b/src/refactorer/long_method_refactorer.py deleted file mode 100644 index 459a32e4..00000000 --- a/src/refactorer/long_method_refactorer.py +++ /dev/null @@ -1,14 +0,0 @@ -from .base_refactorer import BaseRefactorer - -class LongMethodRefactorer(BaseRefactorer): - """ - Refactorer that targets long methods to improve readability. - """ - - def refactor(self): - """ - Refactor long methods into smaller methods. - Implement the logic to detect and refactor long methods. - """ - # Logic to identify long methods goes here - pass diff --git a/src/testing/test_runner.py b/src/testing/test_runner.py deleted file mode 100644 index 84fe92a9..00000000 --- a/src/testing/test_runner.py +++ /dev/null @@ -1,17 +0,0 @@ -import unittest -import os -import sys - -# Add the src directory to the path to import modules -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) - -# Discover and run all tests in the 'tests' directory -def run_tests(): - test_loader = unittest.TestLoader() - test_suite = test_loader.discover('tests', pattern='*.py') - - test_runner = unittest.TextTestRunner(verbosity=2) - test_runner.run(test_suite) - -if __name__ == '__main__': - run_tests() diff --git a/src/testing/test_validator.py b/src/testing/test_validator.py deleted file mode 100644 index cbbb29d4..00000000 --- a/src/testing/test_validator.py +++ /dev/null @@ -1,3 +0,0 @@ -def validate_output(original, refactored): - # Compare original and refactored output - return original == refactored diff --git a/src/utils/logger.py b/src/utils/logger.py deleted file mode 100644 index 711c62b5..00000000 --- a/src/utils/logger.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging -import os - -def setup_logger(log_file: str = "app.log", log_level: int = logging.INFO): - """ - Set up the logger configuration. - - Args: - log_file (str): The name of the log file to write logs to. 
- log_level (int): The logging level (default is INFO). - - Returns: - Logger: Configured logger instance. - """ - # Create log directory if it does not exist - log_directory = os.path.dirname(log_file) - if log_directory and not os.path.exists(log_directory): - os.makedirs(log_directory) - - # Configure the logger - logging.basicConfig( - filename=log_file, - filemode='a', # Append mode - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - level=log_level, - ) - - logger = logging.getLogger(__name__) - return logger - -# # Example usage -# if __name__ == "__main__": -# logger = setup_logger() # You can customize the log file and level here -# logger.info("Logger is set up and ready to use.") diff --git a/test/test_analyzer.py b/test/test_analyzer.py deleted file mode 100644 index 3f522dd4..00000000 --- a/test/test_analyzer.py +++ /dev/null @@ -1,12 +0,0 @@ -# import unittest -# from src.analyzer.pylint_analyzer import PylintAnalyzer - -# class TestPylintAnalyzer(unittest.TestCase): -# def test_analyze_method(self): -# analyzer = PylintAnalyzer("path/to/test/code.py") -# report = analyzer.analyze() -# self.assertIsInstance(report, list) # Check if the output is a list -# # Add more assertions based on expected output - -# if __name__ == "__main__": -# unittest.main() diff --git a/test/test_end_to_end.py b/test/test_end_to_end.py deleted file mode 100644 index bef67b8e..00000000 --- a/test/test_end_to_end.py +++ /dev/null @@ -1,16 +0,0 @@ -import unittest - -class TestEndToEnd(unittest.TestCase): - """ - End-to-end tests for the full refactoring flow. - """ - - def test_refactor_flow(self): - """ - Test the complete flow from analysis to refactoring. 
- """ - # Implement the test logic here - self.assertTrue(True) # Placeholder for actual test - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_energy_measure.py b/test/test_energy_measure.py deleted file mode 100644 index 00d381c6..00000000 --- a/test/test_energy_measure.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest -from src.measurement.energy_meter import EnergyMeter - -class TestEnergyMeter(unittest.TestCase): - """ - Unit tests for the EnergyMeter class. - """ - - def test_measurement(self): - """ - Test starting and stopping energy measurement. - """ - meter = EnergyMeter() - meter.start_measurement() - # Logic to execute code - result = meter.stop_measurement() - self.assertIsNotNone(result) # Check that a result is produced - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_refactorer.py b/test/test_refactorer.py deleted file mode 100644 index af992428..00000000 --- a/test/test_refactorer.py +++ /dev/null @@ -1,99 +0,0 @@ -import unittest -from src.refactorer.long_method_refactorer import LongMethodRefactorer -from src.refactorer.large_class_refactorer import LargeClassRefactorer -from src.refactorer.complex_list_comprehension_refactorer import ComplexListComprehensionRefactorer - -class TestRefactorers(unittest.TestCase): - """ - Unit tests for various refactorers. - """ - - def test_refactor_long_method(self): - """ - Test the refactor method of the LongMethodRefactorer. - """ - original_code = """ - def long_method(): - # A long method with too many lines of code - a = 1 - b = 2 - c = a + b - # More complex logic... 
- return c - """ - expected_refactored_code = """ - def long_method(): - result = calculate_result() - return result - - def calculate_result(): - a = 1 - b = 2 - return a + b - """ - refactorer = LongMethodRefactorer(original_code) - result = refactorer.refactor() - self.assertEqual(result.strip(), expected_refactored_code.strip()) - - def test_refactor_large_class(self): - """ - Test the refactor method of the LargeClassRefactorer. - """ - original_code = """ - class LargeClass: - def method1(self): - # Method 1 - pass - - def method2(self): - # Method 2 - pass - - def method3(self): - # Method 3 - pass - - # ... many more methods ... - """ - expected_refactored_code = """ - class LargeClass: - def method1(self): - # Method 1 - pass - - class AnotherClass: - def method2(self): - # Method 2 - pass - - def method3(self): - # Method 3 - pass - """ - refactorer = LargeClassRefactorer(original_code) - result = refactorer.refactor() - self.assertEqual(result.strip(), expected_refactored_code.strip()) - - def test_refactor_complex_list_comprehension(self): - """ - Test the refactor method of the ComplexListComprehensionRefactorer. 
- """ - original_code = """ - def complex_list(): - return [x**2 for x in range(10) if x % 2 == 0 and x > 3] - """ - expected_refactored_code = """ - def complex_list(): - result = [] - for x in range(10): - if x % 2 == 0 and x > 3: - result.append(x**2) - return result - """ - refactorer = ComplexListComprehensionRefactorer(original_code) - result = refactorer.refactor() - self.assertEqual(result.strip(), expected_refactored_code.strip()) - -# Run all tests in the module -if __name__ == "__main__": - unittest.main() diff --git a/test/README.md b/tests/README.md similarity index 100% rename from test/README.md rename to tests/README.md diff --git a/tests/_input_copies/test_2_copy.py b/tests/_input_copies/test_2_copy.py new file mode 100644 index 00000000..4d1f853d --- /dev/null +++ b/tests/_input_copies/test_2_copy.py @@ -0,0 +1,105 @@ +import datetime # unused import + + +class Temp: + + def __init__(self) -> None: + self.unused_class_attribute = True + self.a = 3 + + def temp_function(self): + unused_var = 3 + b = 4 + return self.a + b + + +# LC: Large Class with too many responsibilities +class DataProcessor: + def __init__(self, data): + self.data = data + self.processed_data = [] + + # LM: Long Method - this method does way too much + def process_all_data(self): + results = [] + for item in self.data: + try: + # LPL: Long Parameter List + result = self.complex_calculation( + item, True, False, "multiply", 10, 20, None, "end" + ) + results.append(result) + except ( + Exception + ) as e: # UEH: Unqualified Exception Handling, catching generic exceptions + print("An error occurred:", e) + + # LMC: Long Message Chain + print(self.data[0].upper().strip().replace(" ", "_").lower()) + + # LLF: Long Lambda Function + self.processed_data = list( + filter(lambda x: x != None and x != 0 and len(str(x)) > 1, results) + ) + + return self.processed_data + + # LBCL: Long Base Class List + + +class AdvancedProcessor(DataProcessor): + pass + + # LTCE: Long Ternary Conditional 
Expression + def check_data(self, item): + return ( + True if item > 10 else False if item < -10 else None if item == 0 else item + ) + + # Complex List Comprehension + def complex_comprehension(self): + # CLC: Complex List Comprehension + self.processed_data = [ + x**2 if x % 2 == 0 else x**3 + for x in range(1, 100) + if x % 5 == 0 and x != 50 and x > 3 + ] + + # Long Element Chain + def long_chain(self): + # LEC: Long Element Chain accessing deeply nested elements + try: + deep_value = self.data[0][1]["details"]["info"]["more_info"][2]["target"] + return deep_value + except KeyError: + return None + + # Long Scope Chaining (LSC) + def long_scope_chaining(self): + for a in range(10): + for b in range(10): + for c in range(10): + for d in range(10): + for e in range(10): + if a + b + c + d + e > 25: + return "Done" + + # LPL: Long Parameter List + def complex_calculation( + self, item, flag1, flag2, operation, threshold, max_value, option, final_stage + ): + if operation == "multiply": + result = item * threshold + elif operation == "add": + result = item + max_value + else: + result = item + return result + + +# Main method to execute the code +if __name__ == "__main__": + sample_data = [1, 2, 3, 4, 5] + processor = DataProcessor(sample_data) + processed = processor.process_all_data() + print("Processed Data:", processed) diff --git a/tests/analyzers/test_long_element_chain_analyzer.py b/tests/analyzers/test_long_element_chain_analyzer.py new file mode 100644 index 00000000..d6d63cb5 --- /dev/null +++ b/tests/analyzers/test_long_element_chain_analyzer.py @@ -0,0 +1,300 @@ +import ast +from pathlib import Path +import textwrap +import pytest + +from ecooptimizer.analyzers.ast_analyzers.detect_long_element_chain import detect_long_element_chain +from ecooptimizer.data_types.smell import LECSmell + + +@pytest.fixture +def temp_file(tmp_path): + """Create a temporary file for testing.""" + file_path = tmp_path / "test_code.py" + return file_path + + +def 
parse_code(code_str): + """Parse code string into an AST.""" + return ast.parse(code_str) + + +def test_no_chains(temp_file): + """Test with code that has no chains.""" + code = textwrap.dedent(""" + a = 1 + b = 2 + c = a + b + d = {'key': 'value'} + e = d['key'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree) + + assert len(result) == 0 + + +def test_chains_below_threshold(temp_file): + """Test with chains shorter than threshold.""" + code = textwrap.dedent(""" + a = {'key1': {'key2': 'value'}} + b = a['key1']['key2'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + # Using threshold of 5 + result = detect_long_element_chain(temp_file, tree, 5) + + assert len(result) == 0 + + +def test_chains_at_threshold(temp_file): + """Test with chains exactly at threshold.""" + code = textwrap.dedent(""" + a = {'key1': {'key2': {'key3': 'value'}}} + b = a['key1']['key2']['key3'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + # Using threshold of 3 + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 1 + assert result[0].messageId == "LEC001" + assert result[0].symbol == "long-element-chain" + assert result[0].occurences[0].line == 3 # Line 3 in the code + + +def test_chains_above_threshold(temp_file): + """Test with chains longer than threshold.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': {'d': 'value'}}}} + result = data['a']['b']['c']['d'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + # Using threshold of 3 + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 1 + assert "Dictionary chain too long (4/3)" in result[0].message + + +def test_multiple_chains(temp_file): + """Test with multiple chains in the same file.""" + code = textwrap.dedent(""" + data1 = {'a': {'b': 
{'c': 'value1'}}} + data2 = {'x': {'y': {'z': 'value2'}}} + + result1 = data1['a']['b']['c'] + result2 = data2['x']['y']['z'] + + # Some other code without chains + a = 1 + b = 2 + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 2 + assert result[0].occurences[0].line != result[1].occurences[0].line + + +def test_nested_functions_with_chains(temp_file): + """Test chains inside nested functions and classes.""" + code = textwrap.dedent(""" + def outer_function(): + data = {'a': {'b': {'c': 'value'}}} + + def inner_function(): + return data['a']['b']['c'] + + return inner_function() + + class TestClass: + def method(self): + obj = {'x': {'y': {'z': {'deep': 'nested'}}}} + return obj['x']['y']['z']['deep'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 2 + # Check that we detected the chain in both locations + + +def test_same_line_reported_once(temp_file): + """Test that chains on the same line are reported only once.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 'value1'}}} + # Two identical chains on the same line + result1, result2 = data['a']['b']['c'], data['a']['b']['c'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 2) + + assert len(result) == 1 + + assert result[0].occurences[0].line == 4 + + +def test_variable_types_chains(temp_file): + """Test chains with different variable types.""" + code = textwrap.dedent(""" + # List within dict chain + data1 = {'a': [{'b': {'c': 'value'}}]} + result1 = data1['a'][0]['b']['c'] + + # Tuple with dict chain + data2 = {'x': ({'y': {'z': 'value'}},)} + result2 = data2['x'][0]['y']['z'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = 
parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 2 + + +def test_custom_threshold(temp_file): + """Test with a custom threshold value.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 'value'}}} + result = data['a']['b']['c'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + + # With threshold of 4, no chains should be detected + result1 = detect_long_element_chain(temp_file, tree, 4) + assert len(result1) == 0 + + # With threshold of 2, the chain should be detected + result2 = detect_long_element_chain(temp_file, tree, 2) + assert len(result2) == 1 + assert "Dictionary chain too long (3/2)" in result2[0].message + + +def test_result_structure(temp_file): + """Test the structure of the returned LECSmell object.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 'value'}}} + result = data['a']['b']['c'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 1 + smell = result[0] + + # Verify it's the correct type + assert isinstance(smell, LECSmell) + + # Check required fields + assert smell.path == str(temp_file) + assert smell.module == temp_file.stem + assert smell.type == "convention" + assert smell.symbol == "long-element-chain" + assert "Dictionary chain too long" in smell.message + + # Check occurrence details + assert len(smell.occurences) == 1 + assert smell.occurences[0].line == 3 + assert smell.occurences[0].column is not None + assert smell.occurences[0].endLine is not None + assert smell.occurences[0].endColumn is not None + + # Verify additional info exists + assert hasattr(smell, "additionalInfo") + + +def test_complex_expressions(temp_file): + """Test chains within complex expressions.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 5}}} + + # Chain in an arithmetic expression + result1 = data['a']['b']['c'] + 
10 + + # Chain in a function call + def my_func(x): + return x * 2 + + result2 = my_func(data['a']['b']['c']) + + # Chain in a comprehension + result3 = [i * data['a']['b']['c'] for i in range(5)] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 3 # Should detect all three chains + + +def test_edge_case_empty_file(temp_file): + """Test with an empty file.""" + code = "" + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree) + + assert len(result) == 0 + + +def test_edge_case_threshold_one(temp_file): + """Test with threshold of 1 (every subscript would be reported).""" + code = textwrap.dedent(""" + data1 = {'a': [{'b': {'c': {'d': 'value'}}}]} + result1 = data1['a'][0]['b']['c']['d'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 5) + + assert len(result) == 1 + assert "Dictionary chain too long (5/5)" in result[0].message diff --git a/tests/analyzers/test_long_lambda_element_analyzer.py b/tests/analyzers/test_long_lambda_element_analyzer.py new file mode 100644 index 00000000..e25e91f1 --- /dev/null +++ b/tests/analyzers/test_long_lambda_element_analyzer.py @@ -0,0 +1,177 @@ +import ast +import textwrap +from pathlib import Path +from unittest.mock import patch + +from ecooptimizer.data_types.smell import LLESmell +from ecooptimizer.analyzers.ast_analyzers.detect_long_lambda_expression import ( + detect_long_lambda_expression, +) + +def test_no_lambdas(): + """Ensures no smells are detected when no lambda is present.""" + code = textwrap.dedent( + """ + def example(): + x = 42 + return x + 1 + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + assert len(smells) == 0 + 
+ +def test_short_single_lambda(): + """ + A single short lambda (well under length=100) + and only one expression -> should NOT be flagged. + """ + code = textwrap.dedent( + """ + def example(): + f = lambda x: x + 1 + return f(5) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 0 + + +def test_lambda_exceeds_expr_count(): + """ + Long lambda due to too many expressions + In the AST, this breaks down as: + (x + 1 if x > 0 else 0) -> ast.IfExp (expression #1) + abs(x) * 2 -> ast.BinOp (Call inside it) (expression #2) + min(x, 5) -> ast.Call (expression #3) + """ + code = textwrap.dedent( + """ + def example(): + func = lambda x: (x + 1 if x > 0 else 0) + (x * 2 if x < 5 else 5) + abs(x) + return func(4) + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 1, "Expected smell due to expression count" + assert isinstance(smells[0], LLESmell) + + +def test_lambda_exceeds_char_length(): + """ + Exceeds threshold_length=100 by using a very long expression in the lambda. 
+ """ + long_str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 4 + code = textwrap.dedent( + f""" + def example(): + func = lambda x: x + "{long_str}" + return func("test") + """ + ) + # exceeds 100 char + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 1, "Expected smell due to character length" + assert isinstance(smells[0], LLESmell) + + +def test_lambda_exceeds_both_thresholds(): + """ + Both too many chars and too many expressions + """ + code = textwrap.dedent( + """ + def example(): + giant_lambda = lambda a, b, c: (a + b if a > b else b - c) + (max(a, b, c) * 10) + (min(a, b, c) / 2) + ("hello" + "world") + return giant_lambda(1,2,3) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + # one smell per line + assert len(smells) >= 1 + assert all(isinstance(smell, LLESmell) for smell in smells) + + +def test_lambda_nested(): + """ + Nested lambdas inside one function. + # outer and inner detected + """ + code = textwrap.dedent( + """ + def example(): + outer = lambda x: (x ** 2) + (lambda y: y + 10)(x) + # inner = lambda y: y + 10 is short, but let's make it long + # We'll artificially make it a big expression + inner = lambda a, b: (a + b if a > 0 else 0) + (a * b) + (b - a) + return outer(5) + inner(3,4) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), ast.parse(code), threshold_length=80, threshold_count=3 + ) + # inner and outter + assert len(smells) == 2 + assert isinstance(smells[0], LLESmell) + + +def test_lambda_inline_passed_to_function(): + """ + Lambdas passed inline to a function: sum(map(...)) or filter(..., lambda). 
+ """ + code = textwrap.dedent( + """ + def test_lambdas(): + result = map(lambda x: x*2 + (x//3) if x > 10 else x, range(20)) + + # This lambda has a ternary, but let's keep it short enough + # that it doesn't trigger by default unless threshold_count=2 or so. + # We'll push it with a second ternary + more code to reach threshold_count=3 + + result2 = filter(lambda z: (z+1 if z < 5 else z-1) + (z*3 if z%2==0 else z/2) and z != 0, result) + + return list(result2) + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + # 2 smells + assert len(smells) == 2 + assert all(isinstance(smell, LLESmell) for smell in smells) + + +def test_lambda_no_body_too_short(): + """ + A degenerate case: a lambda that has no real body or is trivially short. + Should produce 0 smells even if it's spread out. + """ + code = textwrap.dedent( + """ + def example(): + trivial = lambda: None + return trivial() + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + assert len(smells) == 0 diff --git a/tests/analyzers/test_long_message_chain_analyzer.py b/tests/analyzers/test_long_message_chain_analyzer.py new file mode 100644 index 00000000..52326c4e --- /dev/null +++ b/tests/analyzers/test_long_message_chain_analyzer.py @@ -0,0 +1,352 @@ +import ast +import textwrap +from pathlib import Path +from unittest.mock import patch + +from ecooptimizer.data_types.smell import LMCSmell +from ecooptimizer.analyzers.ast_analyzers.detect_long_message_chain import ( + detect_long_message_chain, +) + +# NOTE: The default threshold is 5. That means a chain of 5 or more consecutive calls will be flagged. 
+ + +def test_detects_exact_five_calls_chain(): + """Detects a chain with exactly five method calls.""" + code = textwrap.dedent( + """ + def example(): + details = "some text" + details.upper().lower().capitalize().replace("|", "-").strip() + """ + ) + + # This chain has 5 calls: upper -> lower -> capitalize -> replace -> strip + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected exactly one smell for a chain of length 5" + assert isinstance(smells[0], LMCSmell) + assert "Method chain too long" in smells[0].message + assert smells[0].occurences[0].line == 4 + + +def test_detects_six_calls_chain(): + """Detects a chain with six method calls, definitely flagged.""" + code = textwrap.dedent( + """ + def example(): + details = "some text" + details.upper().lower().upper().capitalize().upper().replace("|", "-") + """ + ) + + # This chain has 6 calls: upper -> lower -> upper -> capitalize -> upper -> replace + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected exactly one smell for a chain of length 6" + assert isinstance(smells[0], LMCSmell) + assert "Method chain too long" in smells[0].message + assert smells[0].occurences[0].line == 4 + + +def test_ignores_chain_of_four_calls(): + """Ensures a chain with only four calls is NOT flagged (below threshold).""" + code = textwrap.dedent( + """ + def example(): + text = "some-other" + text.strip().lower().replace("-", "_").title() + """ + ) + + # This chain has 4 calls: strip -> lower -> replace -> title + # The default threshold is 5, so it should not be detected. 
+ with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Chain of length 4 should NOT be flagged" + + +def test_detects_chain_with_attributes_and_calls(): + """Detects a long chain that involves both attribute and method calls.""" + code = textwrap.dedent( + """ + class Sample: + def __init__(self): + self.details = "some text".upper() + def method(self): + # below is a chain with 5 steps: + # self.details -> lower() -> capitalize() -> isalpha() -> bit_length() + # isalpha() returns bool, bit_length() is from int => means chain length is still counted. + return self.details.upper().lower().capitalize().isalpha().bit_length() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + # Because we have 5 method calls, it should be flagged. + assert len(smells) == 1, "Expected one smell for chain of length >= 5" + assert isinstance(smells[0], LMCSmell) + + +def test_detects_chain_inside_loop(): + """Detects a chain inside a loop that meets the threshold.""" + code = textwrap.dedent( + """ + def loop_chain(data_list): + for item in data_list: + item.strip().replace("-", "_").split("_").index("some") + """ + ) + + # Calls: strip -> replace -> split -> index = 4 calls total. + # add to 5 + code = code.replace('index("some")', 'index("some").upper()') + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected smell for chain length 5" + assert isinstance(smells[0], LMCSmell) + + +def test_multiple_chains_one_line(): + """Detect multiple separate long chains on the same line. 
Should only report 1 smell, the first chain""" + code = textwrap.dedent( + """ + def combo(): + details = "some text" + other = "other text" + details.lower().title().replace("|", "-").upper().split("-"); other.upper().lower().capitalize().zfill(10).replace("xyz", "abc") + """ + ) + + # On line 5, we have two separate chains: + # 1) details -> lower -> title -> replace -> upper -> split => 5 calls. + # 2) other -> upper -> lower -> capitalize -> zfill -> replace => 5 calls. + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + # The function logic says it only reports once per line. So we expect 1 smell, not 2. + assert len(smells) == 1, "Both chains on the same line => single smell reported" + assert "Method chain too long" in smells[0].message + + +def test_ignores_separate_statements(): + """Ensures that separate statements with fewer calls each are not combined into one chain.""" + code = textwrap.dedent( + """ + def example(): + details = "some-other" + data = details.upper() + data = data.lower() + data = data.capitalize() + data = data.replace("|", "-") + data = data.title() + """ + ) + + # Each statement individually has only 1 call. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "No single chain of length >= 5 in separate statements" + + +def test_ignores_short_chain_comprehension(): + """Ensures short chain in a comprehension doesn't get flagged.""" + code = textwrap.dedent( + """ + def short_comp(lst): + return [item.replace("-", "_").lower() for item in lst] + """ + ) + + # Only 2 calls in the chain: replace -> lower. 
+ with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0 + + +def test_detects_long_chain_comprehension(): + """Detects a long chain in a list comprehension.""" + code = textwrap.dedent( + """ + def long_comp(lst): + return [item.upper().lower().capitalize().strip().replace("|", "-") for item in lst] + """ + ) + + # 5 calls in the chain: upper -> lower -> capitalize -> strip -> replace. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected one smell for chain of length 5" + assert isinstance(smells[0], LMCSmell) + + +def test_five_separate_long_chains(): + """ + Five distinct lines in a single function, each with a chain of exactly 5 calls. + Expect 5 separate smells (assuming you record each line). + """ + code = textwrap.dedent( + """ + def combo(): + data = "text" + data.upper().lower().capitalize().replace("|", "-").split("|") + data.capitalize().replace("|", "-").strip().upper().title() + data.lower().upper().replace("|", "-").strip().title() + data.strip().replace("|", "_").split("_").capitalize().title() + data.replace("|", "-").upper().lower().capitalize().title() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 5, "Expected 5 smells" + assert isinstance(smells[0], LMCSmell) + + +def test_element_access_chain_no_calls(): + """ + A chain of attributes and index lookups only, no parentheses (no actual calls). + Some detectors won't flag this unless they specifically count attribute hops. 
+ """ + code = textwrap.dedent( + """ + def get_nested(nested): + return nested.a.b.c[3][0].x.y + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Expected 0 smells" + + +def test_chain_with_slicing(): + """ + Demonstrates slicing as part of the chain. + e.g. `text[2:7]` -> `.replace()` -> `.upper()` ... + """ + code = textwrap.dedent( + """ + def slice_chain(text): + return text[2:7].replace("abc", "xyz").upper().strip().split("-").lower() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_multiline_chain(): + """ + A chain split over multiple lines using parentheses or backslash. + The AST should still see them as a continuous chain of calls. + """ + code = textwrap.dedent( + """ + def multiline_chain(): + var = "some text"\\ + .replace(" ", "-")\\ + .lower()\\ + .title()\\ + .strip()\\ + .upper() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_chain_in_lambda(): + """ + A chain inside a lambda's body. + """ + code = textwrap.dedent( + """ + def lambda_test(): + func = lambda x: x.upper().strip().replace("-", "_").lower().title() + return func("HELLO-WORLD") + """ + ) + # That’s 5 calls: upper -> strip -> replace -> lower -> title + # Expect 1 chain smell if you're scanning inside lambda bodies. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_mixed_return_types_chain(): + """ + It's 5 calls, with type changes from str to bool to int. + Typical 'chain detection' doesn't care about type. 
+ """ + code = textwrap.dedent( + """ + class TypeMix: + def do_stuff(self): + text = "Hello" + return text.lower().capitalize().isalpha().bit_length().to_bytes(2, 'big') + """ + ) + # That’s 5 calls: lower -> capitalize -> isalpha -> bit_length -> to_bytes + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_multiple_short_chains_same_line(): + """ + Two short chains on the same line, each with 3 calls, but they're separate. + They should not combine into 6, so likely 0 smells if threshold=5. + """ + code = textwrap.dedent( + """ + def short_line(): + x = "abc" + y = "def" + x.upper().replace("A", "Z").strip(); y.lower().replace("d", "x").title() + """ + ) + # Each chain is 3 calls, so if threshold is 5, expect 0 smells. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Expected 0 smells" + + +def test_conditional_chain(): + """ + A chain inside an inline if/else expression (ternary). + The question: do we see it as a single chain? Usually yes, but only if we actually parse it as an ast.Call chain. 
+ """ + code = textwrap.dedent( + """ + def cond_chain(cond): + text = "some text" + return (text.lower().replace(" ", "_").strip().upper() if cond + else text.upper().replace(" ", "|").lower().split("|")) + """ + ) + # code shouldnt lump them together + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Expected 0 smells" diff --git a/tests/analyzers/test_repeated_calls_analyzer.py b/tests/analyzers/test_repeated_calls_analyzer.py new file mode 100644 index 00000000..3d0e5acd --- /dev/null +++ b/tests/analyzers/test_repeated_calls_analyzer.py @@ -0,0 +1,132 @@ +import textwrap +from pathlib import Path +from ast import parse +from unittest.mock import patch +from ecooptimizer.data_types.smell import CRCSmell +from ecooptimizer.analyzers.ast_analyzers.detect_repeated_calls import ( + detect_repeated_calls, +) + + +def run_detection_test(code: str): + with patch.object(Path, "read_text", return_value=code): + return detect_repeated_calls(Path("fake.py"), parse(code)) + + +def test_detects_repeated_function_call(): + """Detects repeated function calls within the same scope.""" + code = textwrap.dedent(""" + def test_case(): + result1 = expensive_function(42) + result2 = expensive_function(42) + """) + smells = run_detection_test(code) + + assert len(smells) == 1 + assert isinstance(smells[0], CRCSmell) + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.callString == "expensive_function(42)" + + +def test_detects_repeated_method_call(): + """Detects repeated method calls on the same object instance.""" + code = textwrap.dedent(""" + class Demo: + def compute(self): + return 42 + def test_case(): + obj = Demo() + result1 = obj.compute() + result2 = obj.compute() + """) + smells = run_detection_test(code) + + assert len(smells) == 1 + assert isinstance(smells[0], CRCSmell) + assert len(smells[0].occurences) == 2 + assert 
smells[0].additionalInfo.callString == "obj.compute()" + + +def test_ignores_different_arguments(): + """Ensures repeated function calls with different arguments are NOT flagged.""" + code = textwrap.dedent(""" + def test_case(): + result1 = expensive_function(1) + result2 = expensive_function(2) + """) + smells = run_detection_test(code) + assert len(smells) == 0 + + +def test_ignores_modified_objects(): + """Ensures function calls on modified objects are NOT flagged.""" + code = textwrap.dedent(""" + class Demo: + def compute(self): + return self.value * 2 + def test_case(): + obj = Demo() + obj.value = 10 + result1 = obj.compute() + obj.value = 20 + result2 = obj.compute() + """) + smells = run_detection_test(code) + assert len(smells) == 0 + + +def test_detects_repeated_external_call(): + """Detects repeated external function calls (e.g., len(data.get("key"))).""" + code = textwrap.dedent(""" + def test_case(data): + result = len(data.get("key")) + repeated = len(data.get("key")) + """) + smells = run_detection_test(code) + + assert len(smells) == 1 + assert isinstance(smells[0], CRCSmell) + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.callString == 'len(data.get("key"))' + + +def test_detects_expensive_builtin_call(): + """Detects repeated calls to expensive built-in functions like max().""" + code = textwrap.dedent(""" + def test_case(data): + result1 = max(data) + result2 = max(data) + """) + smells = run_detection_test(code) + + assert len(smells) == 1 + assert isinstance(smells[0], CRCSmell) + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.callString == "max(data)" + + +def test_ignores_primitive_builtins(): + """Ensures built-in functions like abs() are NOT flagged when used with primitives.""" + code = textwrap.dedent(""" + def test_case(): + result1 = abs(-5) + result2 = abs(-5) + """) + smells = run_detection_test(code) + assert len(smells) == 0 + + +def 
test_detects_repeated_method_call_with_different_objects(): + """Ensures method calls on different objects are NOT flagged.""" + code = textwrap.dedent(""" + class Demo: + def compute(self): + return self.value * 2 + def test_case(): + obj1 = Demo() + obj2 = Demo() + result1 = obj1.compute() + result2 = obj2.compute() + """) + smells = run_detection_test(code) + assert len(smells) == 0 diff --git a/tests/analyzers/test_str_concat_analyzer.py b/tests/analyzers/test_str_concat_analyzer.py new file mode 100644 index 00000000..15b9f11d --- /dev/null +++ b/tests/analyzers/test_str_concat_analyzer.py @@ -0,0 +1,542 @@ +from pathlib import Path +from astroid import parse +from unittest.mock import patch + +from ecooptimizer.data_types.smell import SCLSmell +from ecooptimizer.analyzers.astroid_analyzers.detect_string_concat_in_loop import ( + detect_string_concat_in_loop, +) + +# === Basic Concatenation Cases === + + +def test_detects_simple_for_loop_concat(): + """Detects += string concatenation inside a for loop.""" + code = """ + def test(): + result = "" + for i in range(10): + result += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_simple_assign_loop_concat(): + """Detects string concatenation inside a loop.""" + code = """ + def test(): + result = "" + for i in range(10): + result = result + str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert 
smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_simple_while_loop_concat(): + """Detects += string concatenation inside a while loop.""" + code = """ + def test(): + result = "" + while i < 10: + result += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_list_attribute_concat(): + """Detects += modifying a list item inside a loop.""" + code = """ + class Test: + def __init__(self): + self.text = [""] * 5 + def update(self): + for i in range(5): + self.text[0] += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "self.text[0]" + assert smells[0].additionalInfo.innerLoopLine == 6 + + +def test_detects_object_attribute_concat(): + """Detects += modifying an object attribute inside a loop.""" + code = """ + class Test: + def __init__(self): + self.text = "" + def update(self): + for i in range(5): + self.text += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "self.text" + assert smells[0].additionalInfo.innerLoopLine == 6 + + +def test_detects_dict_value_concat(): + """Detects += modifying a dictionary value inside a loop.""" + code = """ + def test(): + data = {"key": ""} + for i in range(5): + data["key"] += str(i) + 
""" + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + # astroid changes double quotes to singles + assert smells[0].additionalInfo.concatTarget == "data['key']" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_multi_loop_concat(): + """Detects multiple separate string concats in a loop.""" + code = """ + def test(): + result = "" + logs = [""] * 4 + for i in range(10): + result += str(i) + logs[0] += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 2 + assert all(isinstance(smell, SCLSmell) for smell in smells) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 5 + + assert len(smells[1].occurences) == 1 + assert smells[1].additionalInfo.concatTarget == "logs[0]" + assert smells[1].additionalInfo.innerLoopLine == 5 + + +def test_detects_reset_loop_concat(): + """Detects string concats with re-assignments inside the loop.""" + code = """ + def reset(): + result = '' + for i in range(5): + result += "Iteration: " + str(i) + if i == 2: + result = "" # Resetting `result` + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === Nested Loop Cases === + + +def test_detects_nested_loop_concat(): + """Detects concatenation inside nested loops.""" + code = """ + def test(): + result = "" + for i in range(3): + for j in range(3): + result 
+= str(j) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 5 + + +def test_detects_complex_nested_loop_concat(): + """Detects multi level concatenations belonging to the same smell.""" + code = """ + def super_complex(): + result = '' + for i in range(5): + result += "Iteration: " + str(i) + for j in range(3): + result += "Nested: " + str(j) # Contributing to `result` + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === Conditional Cases === + + +def test_detects_if_else_concat(): + """Detects += inside an if-else condition within a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + if i % 2 == 0: + result += "even" + else: + result += "odd" + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === String Interpolation Cases === + + +def test_detects_f_string_concat(): + """Detects += using f-strings inside a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + result += f"{i}" + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), 
parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_percent_format_concat(): + """Detects += using % formatting inside a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + result += "%d" % i + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_str_format_concat(): + """Detects += using .format() inside a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + result += "{}".format(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === False Positives (Should NOT Detect) === + + +def test_ignores_access_inside_loop(): + """Ensures that accessing the concatenation variable inside the loop is NOT flagged.""" + code = """ + def test(): + result = "" + for i in range(5): + print(result) # Accessing result mid-loop + result += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +def test_ignores_regular_str_assign_inside_loop(): + """Ensures that regular string assignments are NOT flagged.""" + code = """ + def test(): + result = "" + for i in range(5): + result = str(i) + """ + 
with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +def test_ignores_number_addition_inside_loop(): + """Ensures number operations with the += format are NOT flagged.""" + code = """ + def test(): + num = 1 + for i in range(5): + num += i + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +def test_ignores_concat_outside_loop(): + """Ensures that string concatenation OUTSIDE a loop is NOT flagged.""" + code = """ + def test(): + result = "" + part1 = "Hello" + part2 = "World" + result = result + part1 + part2 + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +# === Edge Cases === + + +def test_detects_sequential_concat(): + """Detects a variable concatenated multiple times in the same loop iteration.""" + code = """ + def test(): + result = "" + for i in range(5): + result += str(i) + result += "-" + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_concat_with_prefix_and_suffix(): + """Detects concatenation where both prefix and suffix are added.""" + code = """ + def test(): + result = "" + for i in range(5): + result = "prefix-" + result + "-suffix" + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert 
smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_prepend_concat(): + """Detects += where new values are inserted at the beginning instead of the end.""" + code = """ + def test(): + result = "" + for i in range(5): + result = str(i) + result + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === Typing Cases === + + +def test_ignores_unknown_type(): + """Ignores potential smells where type cannot be confirmed as a string.""" + code = """ + def test(a, b): + result = a + for i in range(5): + result = result + b + + a = "Hello" + b = "world" + test(a) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +def test_detects_param_type_hint_concat(): + """Detects string concat where type is inferrred from the FunctionDef type hints.""" + code = """ + def test(a: str, b: str): + result = a + for i in range(5): + result = result + b + + a = "Hello" + b = "world" + test(a, b) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_var_type_hint_concat(): + """Detects string concats where the type is inferred from an assign type hint.""" + code = """ + def test(a, b): + result: str = a + for i in range(5): + result = result + b + + a = "Hello" + b = "world" + 
test(a, b) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_cls_attr_type_hint_concat(): + """Detects string concats where type is inferred from class attributes.""" + code = """ + class Test: + + def __init__(self): + self.text = "word" + + def test(self, a): + result = a + for i in range(5): + result = result + self.text + + a = Test() + a.test("this ") + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 9 + + +def test_detects_inferred_str_type_concat(): + """Detects string concat where type is inferred from the initial value assigned.""" + code = """ + def test(a): + result = "" + for i in range(5): + result = a + result + + a = "hello" + test(a) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 diff --git a/tests/api/test_detect_route.py b/tests/api/test_detect_route.py new file mode 100644 index 00000000..150f94b9 --- /dev/null +++ b/tests/api/test_detect_route.py @@ -0,0 +1,81 @@ +from pathlib import Path +from fastapi.testclient import TestClient +from unittest.mock import patch + +from ecooptimizer.api.app import app +from 
ecooptimizer.data_types import Smell +from ecooptimizer.data_types.custom_fields import Occurence + +client = TestClient(app) + + +def get_mock_smell(): + return Smell( + confidence="UNKNOWN", + message="This is a message", + messageId="smellID", + module="module", + obj="obj", + path="fake_path.py", + symbol="smell-symbol", + type="type", + occurences=[ + Occurence( + line=9, + endLine=999, + column=999, + endColumn=999, + ) + ], + ) + + +def test_detect_smells_success(): + request_data = { + "file_path": "fake_path.py", + "enabled_smells": ["smell1", "smell2"], + } + + with patch("pathlib.Path.exists", return_value=True): + with patch( + "ecooptimizer.analyzers.analyzer_controller.AnalyzerController.run_analysis" + ) as mock_run_analysis: + mock_run_analysis.return_value = [get_mock_smell(), get_mock_smell()] + + response = client.post("/smells", json=request_data) + + assert response.status_code == 200 + assert len(response.json()) == 2 + + +def test_detect_smells_file_not_found(): + request_data = { + "file_path": "path/to/nonexistent/file.py", + "enabled_smells": ["smell1", "smell2"], + } + + response = client.post("/smells", json=request_data) + + assert response.status_code == 404 + assert ( + response.json()["detail"] + == f"File not found: {Path('path','to','nonexistent','file.py')!s}" + ) + + +def test_detect_smells_internal_server_error(): + request_data = { + "file_path": "fake_path.py", + "enabled_smells": ["smell1", "smell2"], + } + + with patch("pathlib.Path.exists", return_value=True): + with patch( + "ecooptimizer.analyzers.analyzer_controller.AnalyzerController.run_analysis" + ) as mock_run_analysis: + mock_run_analysis.side_effect = Exception("Internal error") + + response = client.post("/smells", json=request_data) + + assert response.status_code == 500 + assert response.json()["detail"] == "Internal server error" diff --git a/tests/api/test_refactor_route.py b/tests/api/test_refactor_route.py new file mode 100644 index 00000000..79a81155 --- 
/dev/null +++ b/tests/api/test_refactor_route.py @@ -0,0 +1,157 @@ +# ruff: noqa: PT004 +import pytest + +import shutil +from pathlib import Path +from typing import Any +from collections.abc import Generator +from fastapi.testclient import TestClient +from unittest.mock import patch + + +from ecooptimizer.api.app import app +from ecooptimizer.analyzers.analyzer_controller import AnalyzerController +from ecooptimizer.refactorers.refactorer_controller import RefactorerController + + +client = TestClient(app) + +SAMPLE_SMELL = { + "confidence": "UNKNOWN", + "message": "This is a message", + "messageId": "smellID", + "module": "module", + "obj": "obj", + "path": "fake_path.py", + "symbol": "smell-symbol", + "type": "type", + "occurences": [ + { + "line": 9, + "endLine": 999, + "column": 999, + "endColumn": 999, + } + ], +} + +SAMPLE_SOURCE_DIR = "path\\to\\source_dir" + + +@pytest.fixture(scope="module") +def mock_dependencies() -> Generator[None, Any, None]: + """Fixture to mock all dependencies for the /refactor route.""" + with ( + patch.object(Path, "is_dir"), + patch.object(shutil, "copytree"), + patch.object(shutil, "rmtree"), + patch.object( + RefactorerController, + "run_refactorer", + return_value=[ + Path("path/to/modified_file_1.py"), + Path("path/to/modified_file_2.py"), + ], + ), + patch.object(AnalyzerController, "run_analysis", return_value=[SAMPLE_SMELL]), + patch("tempfile.mkdtemp", return_value="/fake/temp/dir"), + ): + yield + + +def test_refactor_success(mock_dependencies): # noqa: ARG001 + """Test the /refactor route with a successful refactoring process.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", side_effect=[10.0, 5.0]): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 200 + assert "refactoredData" in response.json() + assert 
"updatedSmells" in response.json() + assert len(response.json()["updatedSmells"]) == 1 + + +def test_refactor_source_dir_not_found(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when the source directory does not exist.""" + Path.is_dir.return_value = False # type: ignore + + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 404 + assert f"Directory {SAMPLE_SOURCE_DIR} does not exist" in response.json()["detail"] + + +def test_refactor_energy_not_saved(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when no energy is saved after refactoring.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", side_effect=[10.0, 15.0]): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Energy was not saved" in response.json()["detail"] + + +def test_refactor_initial_energy_not_retrieved(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when no energy is saved after refactoring.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", return_value=None): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Could not retrieve initial emissions" in response.json()["detail"] + + +def test_refactor_final_energy_not_retrieved(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when no energy is saved after refactoring.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", side_effect=[10.0, None]): + request_data = { + 
"source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Could not retrieve final emissions" in response.json()["detail"] + + +def test_refactor_unexpected_error(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when an unexpected error occurs during refactoring.""" + Path.is_dir.return_value = True # type: ignore + RefactorerController.run_refactorer.side_effect = Exception("Mock error") # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", return_value=10.0): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Mock error" in response.json()["detail"] diff --git a/tests/benchmarking/__init__.py b/tests/benchmarking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/benchmarking/benchmark.py b/tests/benchmarking/benchmark.py new file mode 100644 index 00000000..fa2f8941 --- /dev/null +++ b/tests/benchmarking/benchmark.py @@ -0,0 +1,201 @@ +# python benchmark.py /path/to/source_file.py + +#!/usr/bin/env python3 +""" +Benchmarking script for ecooptimizer. +This script benchmarks: + 1) Detection/analyzer runtime (via AnalyzerController.run_analysis) + 2) Refactoring runtime (via RefactorerController.run_refactorer) + 3) Energy measurement time (via CodeCarbonEnergyMeter.measure_energy) + +For each detected smell (grouped by smell type), refactoring is run multiple times to compute average times. 
+Usage: python benchmark.py +""" + +# import sys +# import os + +# # Add the src directory to the Python path +# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src"))) + +import time +import statistics +import json +import logging +import shutil +from pathlib import Path +from tempfile import TemporaryDirectory + +# Import controllers and energy measurement module +from ecooptimizer.analyzers.analyzer_controller import AnalyzerController +from ecooptimizer.refactorers.refactorer_controller import RefactorerController +from ecooptimizer.measurements.codecarbon_energy_meter import CodeCarbonEnergyMeter + +TEST_DIR = Path(__file__).parent.resolve() +OUTPUT_DIR = TEST_DIR / "output" +OUTPUT_DIR.mkdir(exist_ok=True) + +# Set up logging configuration +# logging.basicConfig(level=logging.INFO) +# logger = logging.getLogger("benchmark") + +# Create a logger +logger = logging.getLogger("benchmark") + +# Set the global logging level +logger.setLevel(logging.INFO) + +# Create a console handler +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) # You can adjust the level for the console if needed + +# Create a file handler +log_file = OUTPUT_DIR / "benchmark_log.txt" +file_handler = logging.FileHandler(log_file, mode="w") +file_handler.setLevel(logging.INFO) # You can adjust the level for the file if needed + +# Create a formatter +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +console_handler.setFormatter(formatter) +file_handler.setFormatter(formatter) + +# Add both handlers to the logger +logger.addHandler(console_handler) +logger.addHandler(file_handler) + + +def benchmark_detection(source_path: str, iterations: int = 10): + """ + Benchmarks the detection phase. + Runs analyzer_controller.run_analysis multiple times on the given source file, + records the runtime for each iteration, and returns the average detection time. 
+ Also returns the smells data from the final iteration. + """ + analyzer_controller = AnalyzerController() + detection_times = [] + smells_data = None + for i in range(iterations): + start = time.perf_counter() + # Run the analysis; this call detects all smells in the source file. + smells_data = analyzer_controller.run_analysis(Path(source_path)) + end = time.perf_counter() + elapsed = end - start + detection_times.append(elapsed) + logger.info(f"Detection iteration {i+1}/{iterations} took {elapsed:.6f} seconds") + avg_detection = statistics.mean(detection_times) + logger.info(f"Average detection time over {iterations} iterations: {avg_detection:.6f} seconds") + return smells_data, avg_detection + + +def benchmark_refactoring(smells_data, source_path: str, iterations: int = 10): + """ + Benchmarks the refactoring phase for each smell type. + For each smell in smells_data, runs refactoring (using refactorer_controller.run_refactorer) + repeatedly on a temporary copy of the source file. Also measures energy measurement time + (via energy_meter.measure_energy) after refactoring. + Returns two dictionaries: + - refactoring_stats: average refactoring time per smell type + - energy_stats: average energy measurement time per smell type + """ + refactorer_controller = RefactorerController() + energy_meter = CodeCarbonEnergyMeter() + refactoring_stats = {} # smell_type -> average refactoring time + energy_stats = {} # smell_type -> average energy measurement time + + # Group smells by type. (Assuming each smell has a 'messageId' attribute.) + grouped_smells = {} + for smell in smells_data: + smell_type = getattr(smell, "messageId", "unknown") + if smell_type not in grouped_smells: + grouped_smells[smell_type] = [] + grouped_smells[smell_type].append(smell) + + # For each smell type, benchmark refactoring and energy measurement times. 
+ for smell_type, smell_list in grouped_smells.items(): + ref_times = [] + eng_times = [] + logger.info(f"Benchmarking refactoring for smell type: {smell_type}") + for smell in smell_list: + for i in range(iterations): + with TemporaryDirectory() as temp_dir: + # Create a temporary copy of the source file for refactoring. + temp_source = Path(temp_dir) / Path(source_path).name + shutil.copy(Path(source_path), temp_source) + + # Start timer for refactoring. + start_ref = time.perf_counter() + try: + _ = refactorer_controller.run_refactorer( + temp_source, Path(temp_dir), smell, overwrite=False + ) + except NotImplementedError as e: + logger.warning(f"Refactoring not implemented for smell: {e}") + continue + end_ref = time.perf_counter() + ref_time = end_ref - start_ref + ref_times.append(ref_time) + logger.info( + f"Refactoring iteration {i+1}/{iterations} for smell type '{smell_type}' took {ref_time:.6f} seconds" + ) + + # Measure energy measurement time immediately after refactoring. + start_eng = time.perf_counter() + energy_meter.measure_energy(temp_source) + end_eng = time.perf_counter() + eng_time = end_eng - start_eng + eng_times.append(eng_time) + logger.info( + f"Energy measurement iteration {i+1}/{iterations} for smell type '{smell_type}' took {eng_time:.6f} seconds" + ) + + # Compute average times for this smell type. + avg_ref_time = statistics.mean(ref_times) if ref_times else None + avg_eng_time = statistics.mean(eng_times) if eng_times else None + refactoring_stats[smell_type] = avg_ref_time + energy_stats[smell_type] = avg_eng_time + logger.info(f"Smell Type: {smell_type} - Average Refactoring Time: {avg_ref_time:.6f} sec") + logger.info( + f"Smell Type: {smell_type} - Average Energy Measurement Time: {avg_eng_time:.6f} sec" + ) + return refactoring_stats, energy_stats + + +def main(): + """ + Main benchmarking entry point. + Accepts the source file path as a command-line argument. 
+ Runs detection and refactoring benchmarks, then logs and saves overall stats. + """ + # if len(sys.argv) < 2: + # print("Usage: python benchmark.py ") + # sys.exit(1) + + source_file_path = TEST_DIR / "test_code/250_sample.py" + + logger.info(f"Starting benchmark on source file: {source_file_path!s}") + + # Benchmark the detection phase. + smells_data, avg_detection = benchmark_detection(str(source_file_path)) + + # Benchmark the refactoring phase per smell type. + ref_stats, eng_stats = benchmark_refactoring(smells_data, str(source_file_path)) + + # Compile overall benchmark results. + overall_stats = { + "detection_average_time": avg_detection, + "refactoring_times": ref_stats, + "energy_measurement_times": eng_stats, + } + logger.info("Overall Benchmark Results:") + logger.info(json.dumps(overall_stats, indent=4)) + + output_file = OUTPUT_DIR / f"{source_file_path.stem}_benchmark_results.json" + + # Save benchmark results to a JSON file. + with open(output_file, "w") as outfile: # noqa: PTH123 + json.dump(overall_stats, outfile, indent=4) + logger.info(f"Benchmark results saved to {output_file!s}") + + +if __name__ == "__main__": + main() diff --git a/tests/benchmarking/test_code/1000_sample.py b/tests/benchmarking/test_code/1000_sample.py new file mode 100644 index 00000000..bb59ba9d --- /dev/null +++ b/tests/benchmarking/test_code/1000_sample.py @@ -0,0 +1,1322 @@ +""" +This module provides various mathematical helper functions. +It intentionally contains code smells for demonstration purposes. 
+""" + +import collections +import math + + +def long_element_chain(data): + """Access deeply nested elements repeatedly.""" + return data["level1"]["level2"]["level3"]["level4"]["level5"] + + +def long_lambda_function(): + """Creates an unnecessarily long lambda function.""" + return lambda x: (x**2 + 2 * x + 1) / (math.sqrt(x) + x**3 + x**4 + math.sin(x) + math.cos(x)) + + +def long_message_chain(obj): + """Access multiple chained attributes and methods.""" + return obj.get_first().get_second().get_third().get_fourth().get_fifth().value + + +def long_parameter_list(a, b, c, d, e, f, g, h, i, j): + """Function with too many parameters.""" + return (a + b) * (c - d) / (e + f) ** g - h * i + j + + +def member_ignoring_method(self): + """Method that does not use instance attributes.""" + return "I ignore all instance members!" + + +_cache = {} + + +def cached_expensive_call(x): + """Caches repeated calls to avoid redundant computations.""" + if x in _cache: + return _cache[x] + result = math.factorial(x) + math.sqrt(x) + math.log(x + 1) + _cache[x] = result + return result + + +def string_concatenation_in_loop(words): + """Bad practice: String concatenation inside a loop.""" + result = "" + for word in words: + result += word + ", " # Inefficient + return result.strip(", ") + + +# More functions to reach 250 lines with similar issues. 
+def complex_math_operation(a, b, c, d, e, f, g, h): + """Another long parameter list with a complex calculation.""" + return a**b + math.sqrt(c) - math.log(d) + e**f + g / h + + +def factorial_chain(x): + """Long element chain for factorial calculations.""" + return math.factorial(math.ceil(math.sqrt(math.fabs(x)))) + + +def inefficient_fibonacci(n): + """Recursively calculates Fibonacci inefficiently.""" + if n <= 1: + return n + return inefficient_fibonacci(n - 1) + inefficient_fibonacci(n - 2) + + +class MathHelper: + def __init__(self, value): + self.value = value + + def chained_operations(self): + """Demonstrates a long message chain.""" + return self.value.increment().double().square().cube().finalize() + + def ignore_member(self): + """This method does not use 'self' but exists in the class.""" + return "Completely ignores instance attributes!" + + +def expensive_function(x): + return x * x + + +def test_case(): + result1 = expensive_function(42) + result2 = expensive_function(42) + result3 = expensive_function(42) + return result1 + result2 + result3 + + +def long_loop_with_string_concatenation(n): + """Creates a long string inefficiently inside a loop.""" + result = "" + for i in range(n): + result += str(i) + " - " # Inefficient string building + return result.strip(" - ") + + +# More helper functions to reach 250 lines with similar bad practices. 
+def another_long_parameter_list(a, b, c, d, e, f, g, h, i): + """Another example of too many parameters.""" + return a * b + c / d - e**f + g - h + i + + +def contains_large_strings(strings): + return any([len(s) > 10 for s in strings]) + + +def do_god_knows_what(): + mystring = "i hate capstone" + n = 10 + + for i in range(n): + b = 10 + mystring += "word" + + return n + + +def do_something_dumb(): + return + + +class Solution: + def isSameTree(self, p, q): + return ( + p == q + if not p or not q + else p.val == q.val + and self.isSameTree(p.left, q.left) + and self.isSameTree(p.right, q.right) + ) + + +# Code Smell: Long Parameter List +class Vehicle: + def __init__( + self, + make, + model, + year: int, + color, + fuel_type, + engine_start_stop_option, + mileage, + suspension_setting, + transmission, + price, + seat_position_setting=None, + ): + # Code Smell: Long Parameter List in __init__ + self.make = make # positional argument + self.model = model + self.year = year + self.color = color + self.fuel_type = fuel_type + self.engine_start_stop_option = engine_start_stop_option + self.mileage = mileage + self.suspension_setting = suspension_setting + self.transmission = transmission + self.price = price + self.seat_position_setting = seat_position_setting # default value + self.owner = None # Unused class attribute, used in constructor + + def display_info(self): + # Code Smell: Long Message Chain + random_test = self.make.split("") + print( + f"Make: {self.make}, Model: {self.model}, Year: {self.year}".upper().replace(",", "")[ + ::2 + ] + ) + + def calculate_price(self): + # Code Smell: List Comprehension in an All Statement + condition = all( + [ + isinstance(attribute, str) + for attribute in [self.make, self.model, self.year, self.color] + ] + ) + if condition: + return ( + self.price * 0.9 + ) # Apply a 10% discount if all attributes are strings (totally arbitrary condition) + + return self.price + + def unused_method(self): + # Code Smell: Member Ignoring 
Method + print("This method doesn't interact with instance attributes, it just prints a statement.") + + +def longestArithSeqLength2(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +def longestArithSeqLength3(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +def longestArithSeqLength4(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +def longestArithSeqLength5(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +class Calculator: + def add(sum): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + sum = a + b + print("The addition of two numbers:", sum) + + def mul(mul): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + mul = a * b + print("The multiplication of two numbers:", mul) + + def sub(sub): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + sub = a - b + print("The subtraction of two numbers:", sub) + + def div(div): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + div = a / b + print("The division of two numbers: ", div) + + def exp(exp): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + exp = a**b + print("The exponent of the following numbers are: ", exp) + + +class rootop: + def sqrt(self): + a = int(input("Enter number 1: ")) + b = 
int(input("Enter number 2: ")) + print(math.sqrt(a)) + print(math.sqrt(b)) + + def cbrt(self): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + print(a ** (1 / 3)) + print(b ** (1 / 3)) + + def ranroot(self): + a = int(input("Enter the x: ")) + b = int(input("Enter the y: ")) + b_div = 1 / b + print("Your answer for the random root is: ", a**b_div) + + +import random +import string + + +def generate_random_string(length=10): + """Generate a random string of given length.""" + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + +def add_numbers(a, b): + """Return the sum of two numbers.""" + return a + b + + +def multiply_numbers(a, b): + """Return the product of two numbers.""" + return a * b + + +def is_even(n): + """Check if a number is even.""" + return n % 2 == 0 + + +def factorial(n): + """Calculate the factorial of a number recursively.""" + return 1 if n == 0 else n * factorial(n - 1) + + +def reverse_string1(s): + """Reverse a given string.""" + return s[::-1] + + +def count_vowels1(s): + """Count the number of vowels in a string.""" + return sum(1 for char in s.lower() if char in "aeiou") + + +def find_max1(numbers): + """Find the maximum value in a list of numbers.""" + return max(numbers) if numbers else None + + +def shuffle_list1(lst): + """Shuffle a list randomly.""" + random.shuffle(lst) + return lst + + +def fibonacci1(n): + """Generate Fibonacci sequence up to the nth term.""" + sequence = [0, 1] + for _ in range(n - 2): + sequence.append(sequence[-1] + sequence[-2]) + return sequence[:n] + + +def is_palindrome1(s): + """Check if a string is a palindrome.""" + return s == s[::-1] + + +def remove_duplicates1(lst): + """Remove duplicates from a list.""" + return list(set(lst)) + + +def roll_dice(): + """Simulate rolling a six-sided dice.""" + return random.randint(1, 6) + + +def guess_number_game(): + """A simple number guessing game.""" + number = random.randint(1, 100) + attempts = 0 + 
print("Guess a number between 1 and 100!") + while True: + guess = int(input("Enter your guess: ")) + attempts += 1 + if guess < number: + print("Too low!") + elif guess > number: + print("Too high!") + else: + print(f"Correct! You guessed it in {attempts} attempts.") + break + + +def sort_numbers(lst): + """Sort a list of numbers.""" + return sorted(lst) + + +def merge_dicts(d1, d2): + """Merge two dictionaries.""" + return {**d1, **d2} + + +def get_random_element(lst): + """Get a random element from a list.""" + return random.choice(lst) if lst else None + + +def sum_list1(lst): + """Return the sum of elements in a list.""" + return sum(lst) + + +def countdown(n): + """Print a countdown from n to 0.""" + for i in range(n, -1, -1): + print(i) + + +def get_ascii_value(char): + """Return ASCII value of a character.""" + return ord(char) + + +def generate_random_password(length=12): + """Generate a random password.""" + chars = string.ascii_letters + string.digits + string.punctuation + return "".join(random.choice(chars) for _ in range(length)) + + +def find_common_elements(lst1, lst2): + """Find common elements between two lists.""" + return list(set(lst1) & set(lst2)) + + +def print_multiplication_table(n): + """Print multiplication table for a number.""" + for i in range(1, 11): + print(f"{n} x {i} = {n * i}") + + +def most_frequent_element(lst): + """Find the most frequent element in a list.""" + return max(set(lst), key=lst.count) if lst else None + + +def is_prime(n): + """Check if a number is prime.""" + if n < 2: + return False + for i in range(2, int(n**0.5) + 1): + if n % i == 0: + return False + return True + + +def convert_to_binary(n): + """Convert a number to binary.""" + return bin(n)[2:] + + +def sum_of_digits1(n): + """Find the sum of digits of a number.""" + return sum(int(digit) for digit in str(n)) + + +def matrix_transpose(matrix): + """Transpose a matrix.""" + return list(map(list, zip(*matrix))) + + +# Additional random functions to make it 
reach 200 lines +for _ in range(100): + + def temp_func(): + pass + + +# 1. Function to reverse a string +def reverse_string(s): + return s[::-1] + + +# 2. Function to check if a number is prime +def is_prime1(n): + return n > 1 and all(n % i != 0 for i in range(2, int(n**0.5) + 1)) + + +# 3. Function to calculate factorial +def factorial1(n): + return 1 if n <= 1 else n * factorial(n - 1) + + +# 4. Function to find the maximum number in a list +def find_max(lst): + return max(lst) + + +# 5. Function to count vowels in a string +def count_vowels(s): + return sum(1 for char in s if char.lower() in "aeiou") + + +# 6. Function to flatten a nested list +def flatten(lst): + return [item for sublist in lst for item in sublist] + + +# 7. Function to check if a string is a palindrome +def is_palindrome(s): + return s == s[::-1] + + +# 8. Function to generate Fibonacci sequence +def fibonacci(n): + return [0, 1] if n <= 1 else fibonacci(n - 1) + [fibonacci(n - 1)[-1] + fibonacci(n - 1)[-2]] + + +# 9. Function to calculate the area of a circle +def circle_area(r): + return 3.14159 * r**2 + + +# 10. Function to remove duplicates from a list +def remove_duplicates(lst): + return list(set(lst)) + + +# 11. Function to sort a dictionary by value +def sort_dict_by_value(d): + return dict(sorted(d.items(), key=lambda x: x[1])) + + +# 12. Function to count words in a string +def count_words(s): + return len(s.split()) + + +# 13. Function to check if two strings are anagrams +def are_anagrams(s1, s2): + return sorted(s1) == sorted(s2) + + +# 14. Function to find the intersection of two lists +def list_intersection(lst1, lst2): + return list(set(lst1) & set(lst2)) + + +# 15. Function to calculate the sum of digits of a number +def sum_of_digits2(n): + return sum(int(digit) for digit in str(n)) + + +# 16. Function to generate a random password +def generate_password(length=8): + return "".join(random.choice(string.ascii_letters + string.digits) for _ in range(length)) + + +# 21. 
Function to find the longest word in a string +def longest_word(s): + return max(s.split(), key=len) + + +# 22. Function to capitalize the first letter of each word +def capitalize_words(s): + return " ".join(word.capitalize() for word in s.split()) + + +# 23. Function to check if a year is a leap year +def is_leap_year(year): + return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0) + + +# 24. Function to calculate the GCD of two numbers +def gcd1(a, b): + return a if b == 0 else gcd(b, a % b) + + +# 25. Function to calculate the LCM of two numbers +def lcm1(a, b): + return a * b // gcd(a, b) + + +# 26. Function to generate a list of squares +def squares(n): + return [i**2 for i in range(1, n + 1)] + + +# 27. Function to generate a list of cubes +def cubes(n): + return [i**3 for i in range(1, n + 1)] + + +# 28. Function to check if a list is sorted +def is_sorted(lst): + return all(lst[i] <= lst[i + 1] for i in range(len(lst) - 1)) + + +# 29. Function to shuffle a list +def shuffle_list(lst): + random.shuffle(lst) + return lst + + +# 30. Function to find the mode of a list +from collections import Counter + + +def find_mode(lst): + return Counter(lst).most_common(1)[0][0] + + +# 31. Function to calculate the mean of a list +def mean(lst): + return sum(lst) / len(lst) + + +# 32. Function to calculate the median of a list +def median(lst): + lst_sorted = sorted(lst) + mid = len(lst) // 2 + return (lst_sorted[mid] + lst_sorted[~mid]) / 2 + + +# 33. Function to calculate the standard deviation of a list +def std_dev(lst): + m = mean(lst) + return math.sqrt(sum((x - m) ** 2 for x in lst) / len(lst)) + + +# 34. Function to find the nth Fibonacci number +def nth_fibonacci(n): + return fibonacci(n)[-1] + + +# 35. Function to check if a number is even +def is_even1(n): + return n % 2 == 0 + + +# 36. Function to check if a number is odd +def is_odd(n): + return n % 2 != 0 + + +# 37. 
Function to convert Celsius to Fahrenheit +def celsius_to_fahrenheit(c): + return (c * 9 / 5) + 32 + + +# 38. Function to convert Fahrenheit to Celsius +def fahrenheit_to_celsius(f): + return (f - 32) * 5 / 9 + + +# 39. Function to calculate the hypotenuse of a right triangle +def hypotenuse(a, b): + return math.sqrt(a**2 + b**2) + + +# 40. Function to calculate the perimeter of a rectangle +def rectangle_perimeter(l, w): + return 2 * (l + w) + + +# 41. Function to calculate the area of a rectangle +def rectangle_area(l, w): + return l * w + + +# 42. Function to calculate the perimeter of a square +def square_perimeter(s): + return 4 * s + + +# 43. Function to calculate the area of a square +def square_area(s): + return s**2 + + +# 44. Function to calculate the perimeter of a circle +def circle_perimeter(r): + return 2 * 3.14159 * r + + +# 45. Function to calculate the volume of a cube +def cube_volume(s): + return s**3 + + +# 46. Function to calculate the volume of a sphere +def sphere_volume1(r): + return (4 / 3) * 3.14159 * r**3 + + +# 47. Function to calculate the volume of a cylinder +def cylinder_volume1(r, h): + return 3.14159 * r**2 * h + + +# 48. Function to calculate the volume of a cone +def cone_volume1(r, h): + return (1 / 3) * 3.14159 * r**2 * h + + +# 49. Function to calculate the surface area of a cube +def cube_surface_area(s): + return 6 * s**2 + + +# 50. Function to calculate the surface area of a sphere +def sphere_surface_area1(r): + return 4 * 3.14159 * r**2 + + +# 51. Function to calculate the surface area of a cylinder +def cylinder_surface_area1(r, h): + return 2 * 3.14159 * r * (r + h) + + +# 52. Function to calculate the surface area of a cone +def cone_surface_area1(r, l): + return 3.14159 * r * (r + l) + + +# 53. Function to generate a list of random numbers +def random_numbers(n, start=0, end=100): + return [random.randint(start, end) for _ in range(n)] + + +# 54. 
Function to find the index of an element in a list +def find_index(lst, element): + return lst.index(element) if element in lst else -1 + + +# 55. Function to remove an element from a list +def remove_element(lst, element): + return [x for x in lst if x != element] + + +# 56. Function to replace an element in a list +def replace_element(lst, old, new): + return [new if x == old else x for x in lst] + + +# 57. Function to rotate a list by n positions +def rotate_list(lst, n): + return lst[n:] + lst[:n] + + +# 58. Function to find the second largest number in a list +def second_largest(lst): + return sorted(lst)[-2] + + +# 59. Function to find the second smallest number in a list +def second_smallest(lst): + return sorted(lst)[1] + + +# 60. Function to check if all elements in a list are unique +def all_unique(lst): + return len(lst) == len(set(lst)) + + +# 61. Function to find the difference between two lists +def list_difference(lst1, lst2): + return list(set(lst1) - set(lst2)) + + +# 62. Function to find the union of two lists +def list_union(lst1, lst2): + return list(set(lst1) | set(lst2)) + + +# 63. Function to find the symmetric difference of two lists +def symmetric_difference(lst1, lst2): + return list(set(lst1) ^ set(lst2)) + + +# 64. Function to check if a list is a subset of another list +def is_subset(lst1, lst2): + return set(lst1).issubset(set(lst2)) + + +# 65. Function to check if a list is a superset of another list +def is_superset(lst1, lst2): + return set(lst1).issuperset(set(lst2)) + + +# 66. Function to find the frequency of elements in a list +def element_frequency(lst): + return {x: lst.count(x) for x in set(lst)} + + +# 67. Function to find the most frequent element in a list +def most_frequent(lst): + return max(set(lst), key=lst.count) + + +# 68. Function to find the least frequent element in a list +def least_frequent(lst): + return min(set(lst), key=lst.count) + + +# 69. 
Function to find the average of a list of numbers +def average(lst): + return sum(lst) / len(lst) + + +# 70. Function to find the sum of a list of numbers +def sum_list(lst): + return sum(lst) + + +# 71. Function to find the product of a list of numbers +def product_list(lst): + return math.prod(lst) + + +# 72. Function to find the cumulative sum of a list +def cumulative_sum(lst): + return [sum(lst[: i + 1]) for i in range(len(lst))] + + +# 73. Function to find the cumulative product of a list +def cumulative_product(lst): + return [math.prod(lst[: i + 1]) for i in range(len(lst))] + + +# 74. Function to find the difference between consecutive elements in a list +def consecutive_difference(lst): + return [lst[i + 1] - lst[i] for i in range(len(lst) - 1)] + + +# 75. Function to find the ratio between consecutive elements in a list +def consecutive_ratio(lst): + return [lst[i + 1] / lst[i] for i in range(len(lst) - 1)] + + +# 76. Function to find the cumulative difference of a list +def cumulative_difference(lst): + return [lst[0]] + [lst[i] - lst[i - 1] for i in range(1, len(lst))] + + +# 77. Function to find the cumulative ratio of a list +def cumulative_ratio(lst): + return [lst[0]] + [lst[i] / lst[i - 1] for i in range(1, len(lst))] + + +# 78. Function to find the absolute difference between two lists +def absolute_difference(lst1, lst2): + return [abs(lst1[i] - lst2[i]) for i in range(len(lst1))] + + +# 79. Function to find the absolute sum of two lists +def absolute_sum(lst1, lst2): + return [lst1[i] + lst2[i] for i in range(len(lst1))] + + +# 80. Function to find the absolute product of two lists +def absolute_product(lst1, lst2): + return [lst1[i] * lst2[i] for i in range(len(lst1))] + + +# 81. Function to find the absolute ratio of two lists +def absolute_ratio(lst1, lst2): + return [lst1[i] / lst2[i] for i in range(len(lst1))] + + +# 82. 
Function to find the absolute cumulative sum of two lists +def absolute_cumulative_sum(lst1, lst2): + return [sum(lst1[: i + 1]) + sum(lst2[: i + 1]) for i in range(len(lst1))] + + +# 83. Function to find the absolute cumulative product of two lists +def absolute_cumulative_product(lst1, lst2): + return [math.prod(lst1[: i + 1]) * math.prod(lst2[: i + 1]) for i in range(len(lst1))] + + +# 84. Function to find the absolute cumulative difference of two lists +def absolute_cumulative_difference(lst1, lst2): + return [sum(lst1[: i + 1]) - sum(lst2[: i + 1]) for i in range(len(lst1))] + + +# 85. Function to find the absolute cumulative ratio of two lists +def absolute_cumulative_ratio(lst1, lst2): + return [sum(lst1[: i + 1]) / sum(lst2[: i + 1]) for i in range(len(lst1))] + + +# 86. Function to find the absolute cumulative sum of a list +def absolute_cumulative_sum_single(lst): + return [sum(lst[: i + 1]) for i in range(len(lst))] + + +# 87. Function to find the absolute cumulative product of a list +def absolute_cumulative_product_single(lst): + return [math.prod(lst[: i + 1]) for i in range(len(lst))] + + +# 88. Function to find the absolute cumulative difference of a list +def absolute_cumulative_difference_single(lst): + return [sum(lst[: i + 1]) - sum(lst[:i]) for i in range(len(lst))] + + +# 89. Function to find the absolute cumulative ratio of a list +def absolute_cumulative_ratio_single(lst): + return [sum(lst[: i + 1]) / sum(lst[:i]) for i in range(len(lst))] + + +# 90. Function to find the absolute cumulative sum of a list with a constant +def absolute_cumulative_sum_constant(lst, constant): + return [sum(lst[: i + 1]) + constant for i in range(len(lst))] + + +# 91. Function to find the absolute cumulative product of a list with a constant +def absolute_cumulative_product_constant(lst, constant): + return [math.prod(lst[: i + 1]) * constant for i in range(len(lst))] + + +# 92. 
Function to find the absolute cumulative difference of a list with a constant +def absolute_cumulative_difference_constant(lst, constant): + return [sum(lst[: i + 1]) - constant for i in range(len(lst))] + + +# 93. Function to find the absolute cumulative ratio of a list with a constant +def absolute_cumulative_ratio_constant(lst, constant): + return [sum(lst[: i + 1]) / constant for i in range(len(lst))] + + +# 94. Function to find the absolute cumulative sum of a list with a list of constants +def absolute_cumulative_sum_constants(lst, constants): + return [sum(lst[: i + 1]) + constants[i] for i in range(len(lst))] + + +# 95. Function to find the absolute cumulative product of a list with a list of constants +def absolute_cumulative_product_constants(lst, constants): + return [math.prod(lst[: i + 1]) * constants[i] for i in range(len(lst))] + + +# 96. Function to find the absolute cumulative difference of a list with a list of constants +def absolute_cumulative_difference_constants(lst, constants): + return [sum(lst[: i + 1]) - constants[i] for i in range(len(lst))] + + +# 97. Function to find the absolute cumulative ratio of a list with a list of constants +def absolute_cumulative_ratio_constants(lst, constants): + return [sum(lst[: i + 1]) / constants[i] for i in range(len(lst))] + + +# 98. Function to find the absolute cumulative sum of a list with a function +def absolute_cumulative_sum_function(lst, func): + return [sum(lst[: i + 1]) + func(i) for i in range(len(lst))] + + +# 99. Function to find the absolute cumulative product of a list with a function +def absolute_cumulative_product_function(lst, func): + return [math.prod(lst[: i + 1]) * func(i) for i in range(len(lst))] + + +# 100. Function to find the absolute cumulative difference of a list with a function +def absolute_cumulative_difference_function(lst, func): + return [sum(lst[: i + 1]) - func(i) for i in range(len(lst))] + + +# 101. 
Function to find the absolute cumulative ratio of a list with a function +def absolute_cumulative_ratio_function(lst, func): + return [sum(lst[: i + 1]) / func(i) for i in range(len(lst))] + + +# 102. Function to find the absolute cumulative sum of a list with a lambda function +def absolute_cumulative_sum_lambda(lst, func): + return [sum(lst[: i + 1]) + func(i) for i in range(len(lst))] + + +# 103. Function to find the absolute cumulative product of a list with a lambda function +def absolute_cumulative_product_lambda(lst, func): + return [math.prod(lst[: i + 1]) * func(i) for i in range(len(lst))] + + +# 104. Function to find the absolute cumulative difference of a list with a lambda function +def absolute_cumulative_difference_lambda(lst, func): + return [sum(lst[: i + 1]) - func(i) for i in range(len(lst))] + + +# 105. Function to find the absolute cumulative ratio of a list with a lambda function +def absolute_cumulative_ratio_lambda(lst, func): + return [sum(lst[: i + 1]) / func(i) for i in range(len(lst))] + + +# 134. Function to check if a string is a valid email address +def is_valid_email(email): + import re + + pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$" + return bool(re.match(pattern, email)) + + +# 135. Function to generate a list of prime numbers up to a given limit +def generate_primes(limit): + primes = [] + for num in range(2, limit + 1): + if all(num % i != 0 for i in range(2, int(num**0.5) + 1)): + primes.append(num) + return primes + + +# 136. Function to calculate the nth Fibonacci number using recursion +def nth_fibonacci_recursive(n): + if n <= 0: + return 0 + elif n == 1: + return 1 + else: + return nth_fibonacci_recursive(n - 1) + nth_fibonacci_recursive(n - 2) + + +# 137. Function to calculate the nth Fibonacci number using iteration +def nth_fibonacci_iterative(n): + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + + +# 138. 
Function to calculate the factorial of a number using iteration +def factorial_iterative(n): + result = 1 + for i in range(1, n + 1): + result *= i + return result + + +# 139. Function to calculate the factorial of a number using recursion +def factorial_recursive(n): + if n <= 1: + return 1 + else: + return n * factorial_recursive(n - 1) + + +# 140. Function to calculate the sum of all elements in a nested list +def sum_nested_list(lst): + total = 0 + for element in lst: + if isinstance(element, list): + total += sum_nested_list(element) + else: + total += element + return total + + +# 141. Function to flatten a nested list +def flatten_nested_list(lst): + flattened = [] + for element in lst: + if isinstance(element, list): + flattened.extend(flatten_nested_list(element)) + else: + flattened.append(element) + return flattened + + +# 142. Function to find the longest word in a string +def longest_word_in_string(s): + words = s.split() + longest = "" + for word in words: + if len(word) > len(longest): + longest = word + return longest + + +# 143. Function to count the frequency of each character in a string +def character_frequency(s): + frequency = {} + for char in s: + if char in frequency: + frequency[char] += 1 + else: + frequency[char] = 1 + return frequency + + +# 144. Function to check if a number is a perfect square +def is_perfect_square(n): + if n < 0: + return False + sqrt = int(n**0.5) + return sqrt * sqrt == n + + +# 145. Function to check if a number is a perfect cube +def is_perfect_cube(n): + if n < 0: + return False + cube_root = round(n ** (1 / 3)) + return cube_root**3 == n + + +# 146. Function to calculate the sum of squares of the first n natural numbers +def sum_of_squares(n): + return sum(i**2 for i in range(1, n + 1)) + + +# 147. Function to calculate the sum of cubes of the first n natural numbers +def sum_of_cubes(n): + return sum(i**3 for i in range(1, n + 1)) + + +# 148. 
Function to calculate the sum of the digits of a number +def sum_of_digits(n): + total = 0 + while n > 0: + total += n % 10 + n = n // 10 + return total + + +# 149. Function to calculate the product of the digits of a number +def product_of_digits(n): + product = 1 + while n > 0: + product *= n % 10 + n = n // 10 + return product + + +# 150. Function to reverse a number +def reverse_number(n): + reversed_num = 0 + while n > 0: + reversed_num = reversed_num * 10 + n % 10 + n = n // 10 + return reversed_num + + +# 151. Function to check if a number is a palindrome +def is_number_palindrome(n): + return n == reverse_number(n) + + +# 152. Function to generate a list of all divisors of a number +def divisors(n): + divisors = [] + for i in range(1, n + 1): + if n % i == 0: + divisors.append(i) + return divisors + + +# 153. Function to check if a number is abundant +def is_abundant(n): + return sum(divisors(n)) - n > n + + +# 154. Function to check if a number is deficient +def is_deficient(n): + return sum(divisors(n)) - n < n + + +# 155. Function to check if a number is perfect +def is_perfect(n): + return sum(divisors(n)) - n == n + + +# 156. Function to calculate the greatest common divisor (GCD) of two numbers +def gcd(a, b): + while b: + a, b = b, a % b + return a + + +# 157. Function to calculate the least common multiple (LCM) of two numbers +def lcm(a, b): + return a * b // gcd(a, b) + + +# 158. Function to generate a list of the first n triangular numbers +def triangular_numbers(n): + return [i * (i + 1) // 2 for i in range(1, n + 1)] + + +# 159. Function to generate a list of the first n square numbers +def square_numbers(n): + return [i**2 for i in range(1, n + 1)] + + +# 160. Function to generate a list of the first n cube numbers +def cube_numbers(n): + return [i**3 for i in range(1, n + 1)] + + +# 161. Function to calculate the area of a triangle given its base and height +def triangle_area(base, height): + return 0.5 * base * height + + +# 162. 
Function to calculate the area of a trapezoid given its bases and height +def trapezoid_area(base1, base2, height): + return 0.5 * (base1 + base2) * height + + +# 163. Function to calculate the area of a parallelogram given its base and height +def parallelogram_area(base, height): + return base * height + + +# 164. Function to calculate the area of a rhombus given its diagonals +def rhombus_area(diagonal1, diagonal2): + return 0.5 * diagonal1 * diagonal2 + + +# 165. Function to calculate the area of a regular polygon given the number of sides and side length +def regular_polygon_area(n, side_length): + import math + + return (n * side_length**2) / (4 * math.tan(math.pi / n)) + + +# 166. Function to calculate the perimeter of a regular polygon given the number of sides and side length +def regular_polygon_perimeter(n, side_length): + return n * side_length + + +# 167. Function to calculate the volume of a rectangular prism given its dimensions +def rectangular_prism_volume(length, width, height): + return length * width * height + + +# 168. Function to calculate the surface area of a rectangular prism given its dimensions +def rectangular_prism_surface_area(length, width, height): + return 2 * (length * width + width * height + height * length) + + +# 169. Function to calculate the volume of a pyramid given its base area and height +def pyramid_volume(base_area, height): + return (1 / 3) * base_area * height + + +# 170. Function to calculate the surface area of a pyramid given its base area and slant height +def pyramid_surface_area(base_area, slant_height): + return base_area + (1 / 2) * base_area * slant_height + + +# 171. Function to calculate the volume of a cone given its radius and height +def cone_volume(radius, height): + return (1 / 3) * 3.14159 * radius**2 * height + + +# 172. 
Function to calculate the surface area of a cone given its radius and slant height +def cone_surface_area(radius, slant_height): + return 3.14159 * radius * (radius + slant_height) + + +# 173. Function to calculate the volume of a sphere given its radius +def sphere_volume(radius): + return (4 / 3) * 3.14159 * radius**3 + + +# 174. Function to calculate the surface area of a sphere given its radius +def sphere_surface_area(radius): + return 4 * 3.14159 * radius**2 + + +# 175. Function to calculate the volume of a cylinder given its radius and height +def cylinder_volume(radius, height): + return 3.14159 * radius**2 * height + + +# 176. Function to calculate the surface area of a cylinder given its radius and height +def cylinder_surface_area(radius, height): + return 2 * 3.14159 * radius * (radius + height) + + +# 177. Function to calculate the volume of a torus given its major and minor radii +def torus_volume(major_radius, minor_radius): + return 2 * 3.14159**2 * major_radius * minor_radius**2 + + +# 178. Function to calculate the surface area of a torus given its major and minor radii +def torus_surface_area(major_radius, minor_radius): + return 4 * 3.14159**2 * major_radius * minor_radius + + +# 179. Function to calculate the volume of an ellipsoid given its semi-axes +def ellipsoid_volume(a, b, c): + return (4 / 3) * 3.14159 * a * b * c + + +# 180. Function to calculate the surface area of an ellipsoid given its semi-axes +def ellipsoid_surface_area(a, b, c): + # Approximation for surface area of an ellipsoid + p = 1.6075 + return 4 * 3.14159 * ((a**p * b**p + a**p * c**p + b**p * c**p) / 3) ** (1 / p) + + +# 181. Function to calculate the volume of a paraboloid given its radius and height +def paraboloid_volume(radius, height): + return (1 / 2) * 3.14159 * radius**2 * height + + +# 182. 
Function to calculate the surface area of a paraboloid given its radius and height +def paraboloid_surface_area(radius, height): + # Approximation for surface area of a paraboloid + return (3.14159 * radius / (6 * height**2)) * ( + (radius**2 + 4 * height**2) ** (3 / 2) - radius**3 + ) + + +# 183. Function to calculate the volume of a hyperboloid given its radii and height +def hyperboloid_volume(radius1, radius2, height): + return (1 / 3) * 3.14159 * height * (radius1**2 + radius1 * radius2 + radius2**2) + + +# 184. Function to calculate the surface area of a hyperboloid given its radii and height +def hyperboloid_surface_area(radius1, radius2, height): + # Approximation for surface area of a hyperboloid + return 3.14159 * (radius1 + radius2) * math.sqrt((radius1 - radius2) ** 2 + height**2) + + +# 185. Function to calculate the volume of a tetrahedron given its edge length +def tetrahedron_volume(edge_length): + return (edge_length**3) / (6 * math.sqrt(2)) + + +# 186. Function to calculate the surface area of a tetrahedron given its edge length +def tetrahedron_surface_area(edge_length): + return math.sqrt(3) * edge_length**2 + + +# 187. Function to calculate the volume of an octahedron given its edge length +def octahedron_volume(edge_length): + return (math.sqrt(2) / 3) * edge_length**3 + + +if __name__ == "__main__": + print("Math Helper Library Loaded") diff --git a/tests/benchmarking/test_code/250_sample.py b/tests/benchmarking/test_code/250_sample.py new file mode 100644 index 00000000..d549d726 --- /dev/null +++ b/tests/benchmarking/test_code/250_sample.py @@ -0,0 +1,219 @@ +""" +This module provides various mathematical helper functions. +It intentionally contains code smells for demonstration purposes. 
+""" + +import collections +import math + + +def long_element_chain(data): + """Access deeply nested elements repeatedly.""" + return data["level1"]["level2"]["level3"]["level4"]["level5"] + + +def long_lambda_function(): + """Creates an unnecessarily long lambda function.""" + return lambda x: (x**2 + 2 * x + 1) / (math.sqrt(x) + x**3 + x**4 + math.sin(x) + math.cos(x)) + + +def long_message_chain(obj): + """Access multiple chained attributes and methods.""" + return obj.get_first().get_second().get_third().get_fourth().get_fifth().value + + +def long_parameter_list(a, b, c, d, e, f, g, h, i, j): + """Function with too many parameters.""" + return (a + b) * (c - d) / (e + f) ** g - h * i + j + + +def member_ignoring_method(self): + """Method that does not use instance attributes.""" + return "I ignore all instance members!" + + +_cache = {} + + +def cached_expensive_call(x): + """Caches repeated calls to avoid redundant computations.""" + if x in _cache: + return _cache[x] + result = math.factorial(x) + math.sqrt(x) + math.log(x + 1) + _cache[x] = result + return result + + +def string_concatenation_in_loop(words): + """Bad practice: String concatenation inside a loop.""" + result = "" + for word in words: + result += word + ", " # Inefficient + return result.strip(", ") + + +# More functions to reach 250 lines with similar issues. 
+def complex_math_operation(a, b, c, d, e, f, g, h): + """Another long parameter list with a complex calculation.""" + return a**b + math.sqrt(c) - math.log(d) + e**f + g / h + + +def factorial_chain(x): + """Long element chain for factorial calculations.""" + return math.factorial(math.ceil(math.sqrt(math.fabs(x)))) + + +def inefficient_fibonacci(n): + """Recursively calculates Fibonacci inefficiently.""" + if n <= 1: + return n + return inefficient_fibonacci(n - 1) + inefficient_fibonacci(n - 2) + + +class MathHelper: + def __init__(self, value): + self.value = value + + def chained_operations(self): + """Demonstrates a long message chain.""" + return self.value.increment().double().square().cube().finalize() + + def ignore_member(self): + """This method does not use 'self' but exists in the class.""" + return "Completely ignores instance attributes!" + + +def expensive_function(x): + return x * x + + +def test_case(): + result1 = expensive_function(42) + result2 = expensive_function(42) + result3 = expensive_function(42) + return result1 + result2 + result3 + + +def long_loop_with_string_concatenation(n): + """Creates a long string inefficiently inside a loop.""" + result = "" + for i in range(n): + result += str(i) + " - " # Inefficient string building + return result.strip(" - ") + + +# More helper functions to reach 250 lines with similar bad practices. 
+def another_long_parameter_list(a, b, c, d, e, f, g, h, i): + """Another example of too many parameters.""" + return a * b + c / d - e**f + g - h + i + + +def contains_large_strings(strings): + return any([len(s) > 10 for s in strings]) + + +def do_god_knows_what(): + mystring = "i hate capstone" + n = 10 + + for i in range(n): + b = 10 + mystring += "word" + + return n + + +def do_something_dumb(): + return + + +class Solution: + def isSameTree(self, p, q): + return ( + p == q + if not p or not q + else p.val == q.val + and self.isSameTree(p.left, q.left) + and self.isSameTree(p.right, q.right) + ) + + +# Code Smell: Long Parameter List +class Vehicle: + def __init__( + self, + make, + model, + year: int, + color, + fuel_type, + engine_start_stop_option, + mileage, + suspension_setting, + transmission, + price, + seat_position_setting=None, + ): + # Code Smell: Long Parameter List in __init__ + self.make = make # positional argument + self.model = model + self.year = year + self.color = color + self.fuel_type = fuel_type + self.engine_start_stop_option = engine_start_stop_option + self.mileage = mileage + self.suspension_setting = suspension_setting + self.transmission = transmission + self.price = price + self.seat_position_setting = seat_position_setting # default value + self.owner = None # Unused class attribute, used in constructor + + def display_info(self): + # Code Smell: Long Message Chain + random_test = self.make.split("") + print( + f"Make: {self.make}, Model: {self.model}, Year: {self.year}".upper().replace(",", "")[ + ::2 + ] + ) + + def calculate_price(self): + # Code Smell: List Comprehension in an All Statement + condition = all( + [ + isinstance(attribute, str) + for attribute in [self.make, self.model, self.year, self.color] + ] + ) + if condition: + return ( + self.price * 0.9 + ) # Apply a 10% discount if all attributes are strings (totally arbitrary condition) + + return self.price + + def unused_method(self): + # Code Smell: Member Ignoring 
Method + print("This method doesn't interact with instance attributes, it just prints a statement.") + + +def longestArithSeqLength2(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +def longestArithSeqLength3(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +if __name__ == "__main__": + print("Math Helper Library Loaded") diff --git a/tests/benchmarking/test_code/3000_sample.py b/tests/benchmarking/test_code/3000_sample.py new file mode 100644 index 00000000..f8faab14 --- /dev/null +++ b/tests/benchmarking/test_code/3000_sample.py @@ -0,0 +1,3622 @@ +""" +This module provides various mathematical helper functions. +It intentionally contains code smells for demonstration purposes. +""" + +import collections +import math + + +def long_element_chain(data): + """Access deeply nested elements repeatedly.""" + return data["level1"]["level2"]["level3"]["level4"]["level5"] + + +def long_lambda_function(): + """Creates an unnecessarily long lambda function.""" + return lambda x: (x**2 + 2 * x + 1) / (math.sqrt(x) + x**3 + x**4 + math.sin(x) + math.cos(x)) + + +def long_message_chain(obj): + """Access multiple chained attributes and methods.""" + return obj.get_first().get_second().get_third().get_fourth().get_fifth().value + + +def long_parameter_list(a, b, c, d, e, f, g, h, i, j): + """Function with too many parameters.""" + return (a + b) * (c - d) / (e + f) ** g - h * i + j + + +def member_ignoring_method(self): + """Method that does not use instance attributes.""" + return "I ignore all instance members!" 
+ + +_cache = {} + + +def cached_expensive_call(x): + """Caches repeated calls to avoid redundant computations.""" + if x in _cache: + return _cache[x] + result = math.factorial(x) + math.sqrt(x) + math.log(x + 1) + _cache[x] = result + return result + + +def string_concatenation_in_loop(words): + """Bad practice: String concatenation inside a loop.""" + result = "" + for word in words: + result += word + ", " # Inefficient + return result.strip(", ") + + +# More functions to reach 250 lines with similar issues. +def complex_math_operation(a, b, c, d, e, f, g, h): + """Another long parameter list with a complex calculation.""" + return a**b + math.sqrt(c) - math.log(d) + e**f + g / h + + +def factorial_chain(x): + """Long element chain for factorial calculations.""" + return math.factorial(math.ceil(math.sqrt(math.fabs(x)))) + + +def inefficient_fibonacci(n): + """Recursively calculates Fibonacci inefficiently.""" + if n <= 1: + return n + return inefficient_fibonacci(n - 1) + inefficient_fibonacci(n - 2) + + +class MathHelper: + def __init__(self, value): + self.value = value + + def chained_operations(self): + """Demonstrates a long message chain.""" + return self.value.increment().double().square().cube().finalize() + + def ignore_member(self): + """This method does not use 'self' but exists in the class.""" + return "Completely ignores instance attributes!" + + +def expensive_function(x): + return x * x + + +def test_case(): + result1 = expensive_function(42) + result2 = expensive_function(42) + result3 = expensive_function(42) + return result1 + result2 + result3 + + +def long_loop_with_string_concatenation(n): + """Creates a long string inefficiently inside a loop.""" + result = "" + for i in range(n): + result += str(i) + " - " # Inefficient string building + return result.strip(" - ") + + +# More helper functions to reach 250 lines with similar bad practices. 
+def another_long_parameter_list(a, b, c, d, e, f, g, h, i): + """Another example of too many parameters.""" + return a * b + c / d - e**f + g - h + i + + +def contains_large_strings(strings): + return any([len(s) > 10 for s in strings]) + + +def do_god_knows_what(): + mystring = "i hate capstone" + n = 10 + + for i in range(n): + b = 10 + mystring += "word" + + return n + + +def do_something_dumb(): + return + + +class Solution: + def isSameTree(self, p, q): + return ( + p == q + if not p or not q + else p.val == q.val + and self.isSameTree(p.left, q.left) + and self.isSameTree(p.right, q.right) + ) + + +# Code Smell: Long Parameter List +class Vehicle: + def __init__( + self, + make, + model, + year: int, + color, + fuel_type, + engine_start_stop_option, + mileage, + suspension_setting, + transmission, + price, + seat_position_setting=None, + ): + # Code Smell: Long Parameter List in __init__ + self.make = make # positional argument + self.model = model + self.year = year + self.color = color + self.fuel_type = fuel_type + self.engine_start_stop_option = engine_start_stop_option + self.mileage = mileage + self.suspension_setting = suspension_setting + self.transmission = transmission + self.price = price + self.seat_position_setting = seat_position_setting # default value + self.owner = None # Unused class attribute, used in constructor + + def display_info(self): + # Code Smell: Long Message Chain + random_test = self.make.split("") + print( + f"Make: {self.make}, Model: {self.model}, Year: {self.year}".upper().replace(",", "")[ + ::2 + ] + ) + + def calculate_price(self): + # Code Smell: List Comprehension in an All Statement + condition = all( + [ + isinstance(attribute, str) + for attribute in [self.make, self.model, self.year, self.color] + ] + ) + if condition: + return ( + self.price * 0.9 + ) # Apply a 10% discount if all attributes are strings (totally arbitrary condition) + + return self.price + + def unused_method(self): + # Code Smell: Member Ignoring 
Method + print("This method doesn't interact with instance attributes, it just prints a statement.") + + +def longestArithSeqLength2(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +def longestArithSeqLength3(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +def longestArithSeqLength4(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +def longestArithSeqLength5(A: list[int]) -> int: + dp = collections.defaultdict(int) + for i in range(len(A)): + for j in range(i + 1, len(A)): + a, b = A[i], A[j] + dp[b - a, j] = max(dp[b - a, j], dp[b - a, i] + 1) + return max(dp.values()) + 1 + + +class Calculator: + def add(sum): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + sum = a + b + print("The addition of two numbers:", sum) + + def mul(mul): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + mul = a * b + print("The multiplication of two numbers:", mul) + + def sub(sub): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + sub = a - b + print("The subtraction of two numbers:", sub) + + def div(div): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + div = a / b + print("The division of two numbers: ", div) + + def exp(exp): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + exp = a**b + print("The exponent of the following numbers are: ", exp) + + +class rootop: + def sqrt(): + a = int(input("Enter number 1: ")) + b = 
int(input("Enter number 2: ")) + print(math.sqrt(a)) + print(math.sqrt(b)) + + def cbrt(): + a = int(input("Enter number 1: ")) + b = int(input("Enter number 2: ")) + print(a ** (1 / 3)) + print(b ** (1 / 3)) + + def ranroot(): + a = int(input("Enter the x: ")) + b = int(input("Enter the y: ")) + b_div = 1 / b + print("Your answer for the random root is: ", a**b_div) + + +import random +import string + + +def generate_random_string(length=10): + """Generate a random string of given length.""" + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + +def add_numbers(a, b): + """Return the sum of two numbers.""" + return a + b + + +def multiply_numbers(a, b): + """Return the product of two numbers.""" + return a * b + + +def is_even1(n): + """Check if a number is even.""" + return n % 2 == 0 + + +def factorial1(n): + """Calculate the factorial of a number recursively.""" + return 1 if n == 0 else n * factorial(n - 1) + + +def reverse_string1(s): + """Reverse a given string.""" + return s[::-1] + + +def count_vowels1(s): + """Count the number of vowels in a string.""" + return sum(1 for char in s.lower() if char in "aeiou") + + +def find_max1(numbers): + """Find the maximum value in a list of numbers.""" + return max(numbers) if numbers else None + + +def shuffle_list1(lst): + """Shuffle a list randomly.""" + random.shuffle(lst) + return lst + + +def fibonacci1(n): + """Generate Fibonacci sequence up to the nth term.""" + sequence = [0, 1] + for _ in range(n - 2): + sequence.append(sequence[-1] + sequence[-2]) + return sequence[:n] + + +def is_palindrome1(s): + """Check if a string is a palindrome.""" + return s == s[::-1] + + +def remove_duplicates1(lst): + """Remove duplicates from a list.""" + return list(set(lst)) + + +def roll_dice(): + """Simulate rolling a six-sided dice.""" + return random.randint(1, 6) + + +def guess_number_game(): + """A simple number guessing game.""" + number = random.randint(1, 100) + attempts = 0 + 
print("Guess a number between 1 and 100!") + while True: + guess = int(input("Enter your guess: ")) + attempts += 1 + if guess < number: + print("Too low!") + elif guess > number: + print("Too high!") + else: + print(f"Correct! You guessed it in {attempts} attempts.") + break + + +def sort_numbers(lst): + """Sort a list of numbers.""" + return sorted(lst) + + +def merge_dicts(d1, d2): + """Merge two dictionaries.""" + return {**d1, **d2} + + +def get_random_element(lst): + """Get a random element from a list.""" + return random.choice(lst) if lst else None + + +def sum_list1(lst): + """Return the sum of elements in a list.""" + return sum(lst) + + +def countdown(n): + """Print a countdown from n to 0.""" + for i in range(n, -1, -1): + print(i) + + +def get_ascii_value(char): + """Return ASCII value of a character.""" + return ord(char) + + +def generate_random_password(length=12): + """Generate a random password.""" + chars = string.ascii_letters + string.digits + string.punctuation + return "".join(random.choice(chars) for _ in range(length)) + + +def find_common_elements(lst1, lst2): + """Find common elements between two lists.""" + return list(set(lst1) & set(lst2)) + + +def print_multiplication_table(n): + """Print multiplication table for a number.""" + for i in range(1, 11): + print(f"{n} x {i} = {n * i}") + + +def most_frequent_element(lst): + """Find the most frequent element in a list.""" + return max(set(lst), key=lst.count) if lst else None + + +def is_prime1(n): + """Check if a number is prime.""" + if n < 2: + return False + for i in range(2, int(n**0.5) + 1): + if n % i == 0: + return False + return True + + +def convert_to_binary(n): + """Convert a number to binary.""" + return bin(n)[2:] + + +def sum_of_digits2(n): + """Find the sum of digits of a number.""" + return sum(int(digit) for digit in str(n)) + + +def matrix_transpose(matrix): + """Transpose a matrix.""" + return list(map(list, zip(*matrix))) + + +# Additional random functions to make it 
reach 200 lines +for _ in range(100): + + def temp_func(): + pass + + +# 1. Function to reverse a string +def reverse_string(s): + return s[::-1] + + +# 2. Function to check if a number is prime +def is_prime(n): + return n > 1 and all(n % i != 0 for i in range(2, int(n**0.5) + 1)) + + +# 3. Function to calculate factorial +def factorial(n): + return 1 if n <= 1 else n * factorial(n - 1) + + +# 4. Function to find the maximum number in a list +def find_max(lst): + return max(lst) + + +# 5. Function to count vowels in a string +def count_vowels(s): + return sum(1 for char in s if char.lower() in "aeiou") + + +# 6. Function to flatten a nested list +def flatten(lst): + return [item for sublist in lst for item in sublist] + + +# 7. Function to check if a string is a palindrome +def is_palindrome(s): + return s == s[::-1] + + +# 8. Function to generate Fibonacci sequence +def fibonacci(n): + return [0, 1] if n <= 1 else fibonacci(n - 1) + [fibonacci(n - 1)[-1] + fibonacci(n - 1)[-2]] + + +# 9. Function to calculate the area of a circle +def circle_area(r): + return 3.14159 * r**2 + + +# 10. Function to remove duplicates from a list +def remove_duplicates(lst): + return list(set(lst)) + + +# 11. Function to sort a dictionary by value +def sort_dict_by_value(d): + return dict(sorted(d.items(), key=lambda x: x[1])) + + +# 12. Function to count words in a string +def count_words(s): + return len(s.split()) + + +# 13. Function to check if two strings are anagrams +def are_anagrams(s1, s2): + return sorted(s1) == sorted(s2) + + +# 14. Function to find the intersection of two lists +def list_intersection(lst1, lst2): + return list(set(lst1) & set(lst2)) + + +# 15. Function to calculate the sum of digits of a number +def sum_of_digits4(n): + return sum(int(digit) for digit in str(n)) + + +# 16. Function to generate a random password +def generate_password(length=8): + return "".join(random.choice(string.ascii_letters + string.digits) for _ in range(length)) + + +# 21. 
Function to find the longest word in a string +def longest_word(s): + return max(s.split(), key=len) + + +# 22. Function to capitalize the first letter of each word +def capitalize_words(s): + return " ".join(word.capitalize() for word in s.split()) + + +# 23. Function to check if a year is a leap year +def is_leap_year(year): + return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0) + + +# 24. Function to calculate the GCD of two numbers +def gcd4(a, b): + return a if b == 0 else gcd(b, a % b) + + +# 25. Function to calculate the LCM of two numbers +def lcm4(a, b): + return a * b // gcd(a, b) + + +# 26. Function to generate a list of squares +def squares(n): + return [i**2 for i in range(1, n + 1)] + + +# 27. Function to generate a list of cubes +def cubes(n): + return [i**3 for i in range(1, n + 1)] + + +# 28. Function to check if a list is sorted +def is_sorted(lst): + return all(lst[i] <= lst[i + 1] for i in range(len(lst) - 1)) + + +# 29. Function to shuffle a list +def shuffle_list(lst): + random.shuffle(lst) + return lst + + +# 30. Function to find the mode of a list +from collections import Counter + + +def find_mode(lst): + return Counter(lst).most_common(1)[0][0] + + +# 31. Function to calculate the mean of a list +def mean(lst): + return sum(lst) / len(lst) + + +# 32. Function to calculate the median of a list +def median(lst): + lst_sorted = sorted(lst) + mid = len(lst) // 2 + return (lst_sorted[mid] + lst_sorted[~mid]) / 2 + + +# 33. Function to calculate the standard deviation of a list +def std_dev(lst): + m = mean(lst) + return math.sqrt(sum((x - m) ** 2 for x in lst) / len(lst)) + + +# 34. Function to find the nth Fibonacci number +def nth_fibonacci(n): + return fibonacci(n)[-1] + + +# 35. Function to check if a number is even +def is_even(n): + return n % 2 == 0 + + +# 36. Function to check if a number is odd +def is_odd(n): + return n % 2 != 0 + + +# 37. 
Function to convert Celsius to Fahrenheit +def celsius_to_fahrenheit(c): + return (c * 9 / 5) + 32 + + +# 38. Function to convert Fahrenheit to Celsius +def fahrenheit_to_celsius(f): + return (f - 32) * 5 / 9 + + +# 39. Function to calculate the hypotenuse of a right triangle +def hypotenuse(a, b): + return math.sqrt(a**2 + b**2) + + +# 40. Function to calculate the perimeter of a rectangle +def rectangle_perimeter(l, w): + return 2 * (l + w) + + +# 41. Function to calculate the area of a rectangle +def rectangle_area(l, w): + return l * w + + +# 42. Function to calculate the perimeter of a square +def square_perimeter(s): + return 4 * s + + +# 43. Function to calculate the area of a square +def square_area(s): + return s**2 + + +# 44. Function to calculate the perimeter of a circle +def circle_perimeter(r): + return 2 * 3.14159 * r + + +# 45. Function to calculate the volume of a cube +def cube_volume(s): + return s**3 + + +# 46. Function to calculate the volume of a sphere +def sphere_volume1(r): + return (4 / 3) * 3.14159 * r**3 + + +# 47. Function to calculate the volume of a cylinder +def cylinder_volume1(r, h): + return 3.14159 * r**2 * h + + +# 48. Function to calculate the volume of a cone +def cone_volume1(r, h): + return (1 / 3) * 3.14159 * r**2 * h + + +# 49. Function to calculate the surface area of a cube +def cube_surface_area(s): + return 6 * s**2 + + +# 50. Function to calculate the surface area of a sphere +def sphere_surface_area1(r): + return 4 * 3.14159 * r**2 + + +# 51. Function to calculate the surface area of a cylinder +def cylinder_surface_area1(r, h): + return 2 * 3.14159 * r * (r + h) + + +# 52. Function to calculate the surface area of a cone +def cone_surface_area1(r, l): + return 3.14159 * r * (r + l) + + +# 53. Function to generate a list of random numbers +def random_numbers(n, start=0, end=100): + return [random.randint(start, end) for _ in range(n)] + + +# 54. 
Function to find the index of an element in a list +def find_index(lst, element): + return lst.index(element) if element in lst else -1 + + +# 55. Function to remove an element from a list +def remove_element(lst, element): + return [x for x in lst if x != element] + + +# 56. Function to replace an element in a list +def replace_element(lst, old, new): + return [new if x == old else x for x in lst] + + +# 57. Function to rotate a list by n positions +def rotate_list(lst, n): + return lst[n:] + lst[:n] + + +# 58. Function to find the second largest number in a list +def second_largest(lst): + return sorted(lst)[-2] + + +# 59. Function to find the second smallest number in a list +def second_smallest(lst): + return sorted(lst)[1] + + +# 60. Function to check if all elements in a list are unique +def all_unique(lst): + return len(lst) == len(set(lst)) + + +# 61. Function to find the difference between two lists +def list_difference(lst1, lst2): + return list(set(lst1) - set(lst2)) + + +# 62. Function to find the union of two lists +def list_union(lst1, lst2): + return list(set(lst1) | set(lst2)) + + +# 63. Function to find the symmetric difference of two lists +def symmetric_difference(lst1, lst2): + return list(set(lst1) ^ set(lst2)) + + +# 64. Function to check if a list is a subset of another list +def is_subset(lst1, lst2): + return set(lst1).issubset(set(lst2)) + + +# 65. Function to check if a list is a superset of another list +def is_superset(lst1, lst2): + return set(lst1).issuperset(set(lst2)) + + +# 66. Function to find the frequency of elements in a list +def element_frequency(lst): + return {x: lst.count(x) for x in set(lst)} + + +# 67. Function to find the most frequent element in a list +def most_frequent(lst): + return max(set(lst), key=lst.count) + + +# 68. Function to find the least frequent element in a list +def least_frequent(lst): + return min(set(lst), key=lst.count) + + +# 69. 
Function to find the average of a list of numbers +def average(lst): + return sum(lst) / len(lst) + + +# 70. Function to find the sum of a list of numbers +def sum_list(lst): + return sum(lst) + + +# 71. Function to find the product of a list of numbers +def product_list(lst): + return math.prod(lst) + + +# 72. Function to find the cumulative sum of a list +def cumulative_sum(lst): + return [sum(lst[: i + 1]) for i in range(len(lst))] + + +# 73. Function to find the cumulative product of a list +def cumulative_product(lst): + return [math.prod(lst[: i + 1]) for i in range(len(lst))] + + +# 74. Function to find the difference between consecutive elements in a list +def consecutive_difference(lst): + return [lst[i + 1] - lst[i] for i in range(len(lst) - 1)] + + +# 75. Function to find the ratio between consecutive elements in a list +def consecutive_ratio(lst): + return [lst[i + 1] / lst[i] for i in range(len(lst) - 1)] + + +# 76. Function to find the cumulative difference of a list +def cumulative_difference(lst): + return [lst[0]] + [lst[i] - lst[i - 1] for i in range(1, len(lst))] + + +# 77. Function to find the cumulative ratio of a list +def cumulative_ratio(lst): + return [lst[0]] + [lst[i] / lst[i - 1] for i in range(1, len(lst))] + + +# 78. Function to find the absolute difference between two lists +def absolute_difference(lst1, lst2): + return [abs(lst1[i] - lst2[i]) for i in range(len(lst1))] + + +# 79. Function to find the absolute sum of two lists +def absolute_sum(lst1, lst2): + return [lst1[i] + lst2[i] for i in range(len(lst1))] + + +# 80. Function to find the absolute product of two lists +def absolute_product(lst1, lst2): + return [lst1[i] * lst2[i] for i in range(len(lst1))] + + +# 81. Function to find the absolute ratio of two lists +def absolute_ratio(lst1, lst2): + return [lst1[i] / lst2[i] for i in range(len(lst1))] + + +# 82. 
Function to find the absolute cumulative sum of two lists +def absolute_cumulative_sum(lst1, lst2): + return [sum(lst1[: i + 1]) + sum(lst2[: i + 1]) for i in range(len(lst1))] + + +# 83. Function to find the absolute cumulative product of two lists +def absolute_cumulative_product(lst1, lst2): + return [math.prod(lst1[: i + 1]) * math.prod(lst2[: i + 1]) for i in range(len(lst1))] + + +# 84. Function to find the absolute cumulative difference of two lists +def absolute_cumulative_difference(lst1, lst2): + return [sum(lst1[: i + 1]) - sum(lst2[: i + 1]) for i in range(len(lst1))] + + +# 85. Function to find the absolute cumulative ratio of two lists +def absolute_cumulative_ratio(lst1, lst2): + return [sum(lst1[: i + 1]) / sum(lst2[: i + 1]) for i in range(len(lst1))] + + +# 86. Function to find the absolute cumulative sum of a list +def absolute_cumulative_sum_single(lst): + return [sum(lst[: i + 1]) for i in range(len(lst))] + + +# 87. Function to find the absolute cumulative product of a list +def absolute_cumulative_product_single(lst): + return [math.prod(lst[: i + 1]) for i in range(len(lst))] + + +# 88. Function to find the absolute cumulative difference of a list +def absolute_cumulative_difference_single(lst): + return [sum(lst[: i + 1]) - sum(lst[:i]) for i in range(len(lst))] + + +# 89. Function to find the absolute cumulative ratio of a list +def absolute_cumulative_ratio_single(lst): + return [sum(lst[: i + 1]) / sum(lst[:i]) for i in range(len(lst))] + + +# 90. Function to find the absolute cumulative sum of a list with a constant +def absolute_cumulative_sum_constant(lst, constant): + return [sum(lst[: i + 1]) + constant for i in range(len(lst))] + + +# 91. Function to find the absolute cumulative product of a list with a constant +def absolute_cumulative_product_constant(lst, constant): + return [math.prod(lst[: i + 1]) * constant for i in range(len(lst))] + + +# 92. 
Function to find the absolute cumulative difference of a list with a constant +def absolute_cumulative_difference_constant(lst, constant): + return [sum(lst[: i + 1]) - constant for i in range(len(lst))] + + +# 93. Function to find the absolute cumulative ratio of a list with a constant +def absolute_cumulative_ratio_constant(lst, constant): + return [sum(lst[: i + 1]) / constant for i in range(len(lst))] + + +# 94. Function to find the absolute cumulative sum of a list with a list of constants +def absolute_cumulative_sum_constants(lst, constants): + return [sum(lst[: i + 1]) + constants[i] for i in range(len(lst))] + + +# 95. Function to find the absolute cumulative product of a list with a list of constants +def absolute_cumulative_product_constants(lst, constants): + return [math.prod(lst[: i + 1]) * constants[i] for i in range(len(lst))] + + +# 96. Function to find the absolute cumulative difference of a list with a list of constants +def absolute_cumulative_difference_constants(lst, constants): + return [sum(lst[: i + 1]) - constants[i] for i in range(len(lst))] + + +# 97. Function to find the absolute cumulative ratio of a list with a list of constants +def absolute_cumulative_ratio_constants(lst, constants): + return [sum(lst[: i + 1]) / constants[i] for i in range(len(lst))] + + +# 98. Function to find the absolute cumulative sum of a list with a function +def absolute_cumulative_sum_function(lst, func): + return [sum(lst[: i + 1]) + func(i) for i in range(len(lst))] + + +# 99. Function to find the absolute cumulative product of a list with a function +def absolute_cumulative_product_function(lst, func): + return [math.prod(lst[: i + 1]) * func(i) for i in range(len(lst))] + + +# 100. Function to find the absolute cumulative difference of a list with a function +def absolute_cumulative_difference_function(lst, func): + return [sum(lst[: i + 1]) - func(i) for i in range(len(lst))] + + +# 101. 
Function to find the absolute cumulative ratio of a list with a function +def absolute_cumulative_ratio_function(lst, func): + return [sum(lst[: i + 1]) / func(i) for i in range(len(lst))] + + +# 102. Function to find the absolute cumulative sum of a list with a lambda function +def absolute_cumulative_sum_lambda(lst, func): + return [sum(lst[: i + 1]) + func(i) for i in range(len(lst))] + + +# 103. Function to find the absolute cumulative product of a list with a lambda function +def absolute_cumulative_product_lambda(lst, func): + return [math.prod(lst[: i + 1]) * func(i) for i in range(len(lst))] + + +# 104. Function to find the absolute cumulative difference of a list with a lambda function +def absolute_cumulative_difference_lambda(lst, func): + return [sum(lst[: i + 1]) - func(i) for i in range(len(lst))] + + +# 105. Function to find the absolute cumulative ratio of a list with a lambda function +def absolute_cumulative_ratio_lambda(lst, func): + return [sum(lst[: i + 1]) / func(i) for i in range(len(lst))] + + +# 134. Function to check if a string is a valid email address +def is_valid_email1(email): + import re + + pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$" + return bool(re.match(pattern, email)) + + +# 135. Function to generate a list of prime numbers up to a given limit +def generate_primes1(limit): + primes = [] + for num in range(2, limit + 1): + if all(num % i != 0 for i in range(2, int(num**0.5) + 1)): + primes.append(num) + return primes + + +# 136. Function to calculate the nth Fibonacci number using recursion +def nth_fibonacci_recursive1(n): + if n <= 0: + return 0 + elif n == 1: + return 1 + else: + return nth_fibonacci_recursive(n - 1) + nth_fibonacci_recursive(n - 2) + + +# 137. Function to calculate the nth Fibonacci number using iteration +def nth_fibonacci_iterative1(n): + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + + +# 138. 
Function to calculate the factorial of a number using iteration +def factorial_iterative1(n): + result = 1 + for i in range(1, n + 1): + result *= i + return result + + +# 139. Function to calculate the factorial of a number using recursion +def factorial_recursive1(n): + if n <= 1: + return 1 + else: + return n * factorial_recursive(n - 1) + + +# 140. Function to calculate the sum of all elements in a nested list +def sum_nested_list1(lst): + total = 0 + for element in lst: + if isinstance(element, list): + total += sum_nested_list(element) + else: + total += element + return total + + +# 141. Function to flatten a nested list +def flatten_nested_list1(lst): + flattened = [] + for element in lst: + if isinstance(element, list): + flattened.extend(flatten_nested_list(element)) + else: + flattened.append(element) + return flattened + + +# 142. Function to find the longest word in a string +def longest_word_in_string1(s): + words = s.split() + longest = "" + for word in words: + if len(word) > len(longest): + longest = word + return longest + + +# 143. Function to count the frequency of each character in a string +def character_frequency1(s): + frequency = {} + for char in s: + if char in frequency: + frequency[char] += 1 + else: + frequency[char] = 1 + return frequency + + +# 144. Function to check if a number is a perfect square +def is_perfect_square1(n): + if n < 0: + return False + sqrt = int(n**0.5) + return sqrt * sqrt == n + + +# 145. Function to check if a number is a perfect cube +def is_perfect_cube1(n): + if n < 0: + return False + cube_root = round(n ** (1 / 3)) + return cube_root**3 == n + + +# 146. Function to calculate the sum of squares of the first n natural numbers +def sum_of_squares1(n): + return sum(i**2 for i in range(1, n + 1)) + + +# 147. Function to calculate the sum of cubes of the first n natural numbers +def sum_of_cubes1(n): + return sum(i**3 for i in range(1, n + 1)) + + +# 148. 
Function to calculate the sum of the digits of a number +def sum_of_digits1(n): + total = 0 + while n > 0: + total += n % 10 + n = n // 10 + return total + + +# 149. Function to calculate the product of the digits of a number +def product_of_digits1(n): + product = 1 + while n > 0: + product *= n % 10 + n = n // 10 + return product + + +# 150. Function to reverse a number +def reverse_number1(n): + reversed_num = 0 + while n > 0: + reversed_num = reversed_num * 10 + n % 10 + n = n // 10 + return reversed_num + + +# 151. Function to check if a number is a palindrome +def is_number_palindrome1(n): + return n == reverse_number(n) + + +# 152. Function to generate a list of all divisors of a number +def divisors1(n): + divisors = [] + for i in range(1, n + 1): + if n % i == 0: + divisors.append(i) + return divisors + + +# 153. Function to check if a number is abundant +def is_abundant1(n): + return sum(divisors(n)) - n > n + + +# 154. Function to check if a number is deficient +def is_deficient1(n): + return sum(divisors(n)) - n < n + + +# 155. Function to check if a number is perfect +def is_perfect1(n): + return sum(divisors(n)) - n == n + + +# 156. Function to calculate the greatest common divisor (GCD) of two numbers +def gcd1(a, b): + while b: + a, b = b, a % b + return a + + +# 157. Function to calculate the least common multiple (LCM) of two numbers +def lcm1(a, b): + return a * b // gcd(a, b) + + +# 158. Function to generate a list of the first n triangular numbers +def triangular_numbers1(n): + return [i * (i + 1) // 2 for i in range(1, n + 1)] + + +# 159. Function to generate a list of the first n square numbers +def square_numbers1(n): + return [i**2 for i in range(1, n + 1)] + + +# 160. Function to generate a list of the first n cube numbers +def cube_numbers1(n): + return [i**3 for i in range(1, n + 1)] + + +# 161. 
Function to calculate the area of a triangle given its base and height +def triangle_area1(base, height): + return 0.5 * base * height + + +# 162. Function to calculate the area of a trapezoid given its bases and height +def trapezoid_area1(base1, base2, height): + return 0.5 * (base1 + base2) * height + + +# 163. Function to calculate the area of a parallelogram given its base and height +def parallelogram_area1(base, height): + return base * height + + +# 164. Function to calculate the area of a rhombus given its diagonals +def rhombus_area1(diagonal1, diagonal2): + return 0.5 * diagonal1 * diagonal2 + + +# 165. Function to calculate the area of a regular polygon given the number of sides and side length +def regular_polygon_area1(n, side_length): + import math + + return (n * side_length**2) / (4 * math.tan(math.pi / n)) + + +# 166. Function to calculate the perimeter of a regular polygon given the number of sides and side length +def regular_polygon_perimeter1(n, side_length): + return n * side_length + + +# 167. Function to calculate the volume of a rectangular prism given its dimensions +def rectangular_prism_volume1(length, width, height): + return length * width * height + + +# 168. Function to calculate the surface area of a rectangular prism given its dimensions +def rectangular_prism_surface_area1(length, width, height): + return 2 * (length * width + width * height + height * length) + + +# 169. Function to calculate the volume of a pyramid given its base area and height +def pyramid_volume1(base_area, height): + return (1 / 3) * base_area * height + + +# 170. Function to calculate the surface area of a pyramid given its base area and slant height +def pyramid_surface_area1(base_area, slant_height): + return base_area + (1 / 2) * base_area * slant_height + + +# 171. Function to calculate the volume of a cone given its radius and height +def cone_volume2(radius, height): + return (1 / 3) * 3.14159 * radius**2 * height + + +# 172. 
Function to calculate the surface area of a cone given its radius and slant height +def cone_surface_area2(radius, slant_height): + return 3.14159 * radius * (radius + slant_height) + + +# 173. Function to calculate the volume of a sphere given its radius +def sphere_volume2(radius): + return (4 / 3) * 3.14159 * radius**3 + + +# 174. Function to calculate the surface area of a sphere given its radius +def sphere_surface_area2(radius): + return 4 * 3.14159 * radius**2 + + +# 175. Function to calculate the volume of a cylinder given its radius and height +def cylinder_volume2(radius, height): + return 3.14159 * radius**2 * height + + +# 176. Function to calculate the surface area of a cylinder given its radius and height +def cylinder_surface_area2(radius, height): + return 2 * 3.14159 * radius * (radius + height) + + +# 177. Function to calculate the volume of a torus given its major and minor radii +def torus_volume2(major_radius, minor_radius): + return 2 * 3.14159**2 * major_radius * minor_radius**2 + + +# 178. Function to calculate the surface area of a torus given its major and minor radii +def torus_surface_area2(major_radius, minor_radius): + return 4 * 3.14159**2 * major_radius * minor_radius + + +# 179. Function to calculate the volume of an ellipsoid given its semi-axes +def ellipsoid_volume2(a, b, c): + return (4 / 3) * 3.14159 * a * b * c + + +# 180. Function to calculate the surface area of an ellipsoid given its semi-axes +def ellipsoid_surface_area2(a, b, c): + # Approximation for surface area of an ellipsoid + p = 1.6075 + return 4 * 3.14159 * ((a**p * b**p + a**p * c**p + b**p * c**p) / 3) ** (1 / p) + + +# 181. Function to calculate the volume of a paraboloid given its radius and height +def paraboloid_volume2(radius, height): + return (1 / 2) * 3.14159 * radius**2 * height + + +# 182. 
Function to calculate the surface area of a paraboloid given its radius and height +def paraboloid_surface_area2(radius, height): + # Approximation for surface area of a paraboloid + return (3.14159 * radius / (6 * height**2)) * ( + (radius**2 + 4 * height**2) ** (3 / 2) - radius**3 + ) + + +# 183. Function to calculate the volume of a hyperboloid given its radii and height +def hyperboloid_volume2(radius1, radius2, height): + return (1 / 3) * 3.14159 * height * (radius1**2 + radius1 * radius2 + radius2**2) + + +# 184. Function to calculate the surface area of a hyperboloid given its radii and height +def hyperboloid_surface_area2(radius1, radius2, height): + # Approximation for surface area of a hyperboloid + return 3.14159 * (radius1 + radius2) * math.sqrt((radius1 - radius2) ** 2 + height**2) + + +# 185. Function to calculate the volume of a tetrahedron given its edge length +def tetrahedron_volume2(edge_length): + return (edge_length**3) / (6 * math.sqrt(2)) + + +# 186. Function to calculate the surface area of a tetrahedron given its edge length +def tetrahedron_surface_area2(edge_length): + return math.sqrt(3) * edge_length**2 + + +# 187. Function to calculate the volume of an octahedron given its edge length +def octahedron_volume2(edge_length): + return (math.sqrt(2) / 3) * edge_length**3 + + +# 134. Function to check if a string is a valid email address +def is_valid_email(email): + import re + + pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$" + return bool(re.match(pattern, email)) + + +# 135. Function to generate a list of prime numbers up to a given limit +def generate_primes(limit): + primes = [] + for num in range(2, limit + 1): + if all(num % i != 0 for i in range(2, int(num**0.5) + 1)): + primes.append(num) + return primes + + +# 136. 
Function to calculate the nth Fibonacci number using recursion +def nth_fibonacci_recursive(n): + if n <= 0: + return 0 + elif n == 1: + return 1 + else: + return nth_fibonacci_recursive(n - 1) + nth_fibonacci_recursive(n - 2) + + +# 137. Function to calculate the nth Fibonacci number using iteration +def nth_fibonacci_iterative(n): + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + + +# 138. Function to calculate the factorial of a number using iteration +def factorial_iterative(n): + result = 1 + for i in range(1, n + 1): + result *= i + return result + + +# 139. Function to calculate the factorial of a number using recursion +def factorial_recursive(n): + if n <= 1: + return 1 + else: + return n * factorial_recursive(n - 1) + + +# 140. Function to calculate the sum of all elements in a nested list +def sum_nested_list(lst): + total = 0 + for element in lst: + if isinstance(element, list): + total += sum_nested_list(element) + else: + total += element + return total + + +# 141. Function to flatten a nested list +def flatten_nested_list(lst): + flattened = [] + for element in lst: + if isinstance(element, list): + flattened.extend(flatten_nested_list(element)) + else: + flattened.append(element) + return flattened + + +# 142. Function to find the longest word in a string +def longest_word_in_string(s): + words = s.split() + longest = "" + for word in words: + if len(word) > len(longest): + longest = word + return longest + + +# 143. Function to count the frequency of each character in a string +def character_frequency3(s): + frequency = {} + for char in s: + if char in frequency: + frequency[char] += 1 + else: + frequency[char] = 1 + return frequency + + +# 144. Function to check if a number is a perfect square +def is_perfect_square3(n): + if n < 0: + return False + sqrt = int(n**0.5) + return sqrt * sqrt == n + + +# 145. 
Function to check if a number is a perfect cube +def is_perfect_cube3(n): + if n < 0: + return False + cube_root = round(n ** (1 / 3)) + return cube_root**3 == n + + +# 146. Function to calculate the sum of squares of the first n natural numbers +def sum_of_squares3(n): + return sum(i**2 for i in range(1, n + 1)) + + +# 147. Function to calculate the sum of cubes of the first n natural numbers +def sum_of_cubes3(n): + return sum(i**3 for i in range(1, n + 1)) + + +# 148. Function to calculate the sum of the digits of a number +def sum_of_digits3(n): + total = 0 + while n > 0: + total += n % 10 + n = n // 10 + return total + + +# 149. Function to calculate the product of the digits of a number +def product_of_digits3(n): + product = 1 + while n > 0: + product *= n % 10 + n = n // 10 + return product + + +# 150. Function to reverse a number +def reverse_number3(n): + reversed_num = 0 + while n > 0: + reversed_num = reversed_num * 10 + n % 10 + n = n // 10 + return reversed_num + + +# 151. Function to check if a number is a palindrome +def is_number_palindrome3(n): + return n == reverse_number(n) + + +# 152. Function to generate a list of all divisors of a number +def divisors3(n): + divisors = [] + for i in range(1, n + 1): + if n % i == 0: + divisors.append(i) + return divisors + + +# 153. Function to check if a number is abundant +def is_abundant3(n): + return sum(divisors(n)) - n > n + + +# 154. Function to check if a number is deficient +def is_deficient3(n): + return sum(divisors(n)) - n < n + + +# 155. Function to check if a number is perfect +def is_perfect3(n): + return sum(divisors(n)) - n == n + + +# 156. Function to calculate the greatest common divisor (GCD) of two numbers +def gcd3(a, b): + while b: + a, b = b, a % b + return a + + +# 157. Function to calculate the least common multiple (LCM) of two numbers +def lcm3(a, b): + return a * b // gcd(a, b) + + +# 158. 
Function to generate a list of the first n triangular numbers +def triangular_numbers3(n): + return [i * (i + 1) // 2 for i in range(1, n + 1)] + + +# 159. Function to generate a list of the first n square numbers +def square_numbers3(n): + return [i**2 for i in range(1, n + 1)] + + +# 160. Function to generate a list of the first n cube numbers +def cube_numbers3(n): + return [i**3 for i in range(1, n + 1)] + + +# 161. Function to calculate the area of a triangle given its base and height +def triangle_area3(base, height): + return 0.5 * base * height + + +# 162. Function to calculate the area of a trapezoid given its bases and height +def trapezoid_area3(base1, base2, height): + return 0.5 * (base1 + base2) * height + + +# 163. Function to calculate the area of a parallelogram given its base and height +def parallelogram_area3(base, height): + return base * height + + +# 164. Function to calculate the area of a rhombus given its diagonals +def rhombus_area3(diagonal1, diagonal2): + return 0.5 * diagonal1 * diagonal2 + + +# 165. Function to calculate the area of a regular polygon given the number of sides and side length +def regular_polygon_area3(n, side_length): + import math + + return (n * side_length**2) / (4 * math.tan(math.pi / n)) + + +# 166. Function to calculate the perimeter of a regular polygon given the number of sides and side length +def regular_polygon_perimeter3(n, side_length): + return n * side_length + + +# 167. Function to calculate the volume of a rectangular prism given its dimensions +def rectangular_prism_volume3(length, width, height): + return length * width * height + + +# 168. Function to calculate the surface area of a rectangular prism given its dimensions +def rectangular_prism_surface_area3(length, width, height): + return 2 * (length * width + width * height + height * length) + + +# 169. 
Function to calculate the volume of a pyramid given its base area and height +def pyramid_volume3(base_area, height): + return (1 / 3) * base_area * height + + +# 170. Function to calculate the surface area of a pyramid given its base area and slant height +def pyramid_surface_area3(base_area, slant_height): + return base_area + (1 / 2) * base_area * slant_height + + +# 171. Function to calculate the volume of a cone given its radius and height +def cone_volume3(radius, height): + return (1 / 3) * 3.14159 * radius**2 * height + + +# 172. Function to calculate the surface area of a cone given its radius and slant height +def cone_surface_area3(radius, slant_height): + return 3.14159 * radius * (radius + slant_height) + + +# 173. Function to calculate the volume of a sphere given its radius +def sphere_volume3(radius): + return (4 / 3) * 3.14159 * radius**3 + + +# 174. Function to calculate the surface area of a sphere given its radius +def sphere_surface_area3(radius): + return 4 * 3.14159 * radius**2 + + +# 175. Function to calculate the volume of a cylinder given its radius and height +def cylinder_volume3(radius, height): + return 3.14159 * radius**2 * height + + +# 176. Function to calculate the surface area of a cylinder given its radius and height +def cylinder_surface_area3(radius, height): + return 2 * 3.14159 * radius * (radius + height) + + +# 177. Function to calculate the volume of a torus given its major and minor radii +def torus_volume3(major_radius, minor_radius): + return 2 * 3.14159**2 * major_radius * minor_radius**2 + + +# 178. Function to calculate the surface area of a torus given its major and minor radii +def torus_surface_area3(major_radius, minor_radius): + return 4 * 3.14159**2 * major_radius * minor_radius + + +# 179. Function to calculate the volume of an ellipsoid given its semi-axes +def ellipsoid_volume3(a, b, c): + return (4 / 3) * 3.14159 * a * b * c + + +# 180. 
Function to calculate the surface area of an ellipsoid given its semi-axes +def ellipsoid_surface_area3(a, b, c): + # Approximation for surface area of an ellipsoid + p = 1.6075 + return 4 * 3.14159 * ((a**p * b**p + a**p * c**p + b**p * c**p) / 3) ** (1 / p) + + +# 181. Function to calculate the volume of a paraboloid given its radius and height +def paraboloid_volume3(radius, height): + return (1 / 2) * 3.14159 * radius**2 * height + + +# 182. Function to calculate the surface area of a paraboloid given its radius and height +def paraboloid_surface_area3(radius, height): + # Approximation for surface area of a paraboloid + return (3.14159 * radius / (6 * height**2)) * ( + (radius**2 + 4 * height**2) ** (3 / 2) - radius**3 + ) + + +# 183. Function to calculate the volume of a hyperboloid given its radii and height +def hyperboloid_volume(radius1, radius2, height): + return (1 / 3) * 3.14159 * height * (radius1**2 + radius1 * radius2 + radius2**2) + + +# 184. Function to calculate the surface area of a hyperboloid given its radii and height +def hyperboloid_surface_area(radius1, radius2, height): + # Approximation for surface area of a hyperboloid + return 3.14159 * (radius1 + radius2) * math.sqrt((radius1 - radius2) ** 2 + height**2) + + +# 185. Function to calculate the volume of a tetrahedron given its edge length +def tetrahedron_volume(edge_length): + return (edge_length**3) / (6 * math.sqrt(2)) + + +# 186. Function to calculate the surface area of a tetrahedron given its edge length +def tetrahedron_surface_area(edge_length): + return math.sqrt(3) * edge_length**2 + + +# 187. Function to calculate the volume of an octahedron given its edge length +def octahedron_volume(edge_length): + return (math.sqrt(2) / 3) * edge_length**3 + + +# 188. Function to calculate the surface area of an octahedron given its edge length +def octahedron_surface_area(edge_length): + return 2 * math.sqrt(3) * edge_length**2 + + +# 189. 
Function to calculate the volume of a dodecahedron given its edge length +def dodecahedron_volume(edge_length): + return (15 + 7 * math.sqrt(5)) / 4 * edge_length**3 + + +# 190. Function to calculate the surface area of a dodecahedron given its edge length +def dodecahedron_surface_area(edge_length): + return 3 * math.sqrt(25 + 10 * math.sqrt(5)) * edge_length**2 + + +# 191. Function to calculate the volume of an icosahedron given its edge length +def icosahedron_volume(edge_length): + return (5 * (3 + math.sqrt(5))) / 12 * edge_length**3 + + +# 192. Function to calculate the surface area of an icosahedron given its edge length +def icosahedron_surface_area(edge_length): + return 5 * math.sqrt(3) * edge_length**2 + + +# 193. Function to calculate the volume of a frustum given its radii and height +def frustum_volume(radius1, radius2, height): + return (1 / 3) * 3.14159 * height * (radius1**2 + radius1 * radius2 + radius2**2) + + +# 194. Function to calculate the surface area of a frustum given its radii and height +def frustum_surface_area(radius1, radius2, height): + slant_height = math.sqrt((radius1 - radius2) ** 2 + height**2) + return 3.14159 * (radius1 + radius2) * slant_height + 3.14159 * (radius1**2 + radius2**2) + + +# 195. Function to calculate the volume of a spherical cap given its radius and height +def spherical_cap_volume(radius, height): + return (1 / 3) * 3.14159 * height**2 * (3 * radius - height) + + +# 196. Function to calculate the surface area of a spherical cap given its radius and height +def spherical_cap_surface_area(radius, height): + return 2 * 3.14159 * radius * height + + +# 197. Function to calculate the volume of a spherical segment given its radii and height +def spherical_segment_volume(radius1, radius2, height): + return (1 / 6) * 3.14159 * height * (3 * radius1**2 + 3 * radius2**2 + height**2) + + +# 198. 
Function to calculate the surface area of a spherical segment given its radii and height +def spherical_segment_surface_area(radius1, radius2, height): + return 2 * 3.14159 * radius1 * height + 3.14159 * (radius1**2 + radius2**2) + + +# 199. Function to calculate the volume of a spherical wedge given its radius and angle +def spherical_wedge_volume(radius, angle): + return (2 / 3) * radius**3 * angle + + +# 200. Function to calculate the surface area of a spherical wedge given its radius and angle +def spherical_wedge_surface_area(radius, angle): + return 2 * radius**2 * angle + + +# 201. Function to calculate the volume of a spherical sector given its radius and height +def spherical_sector_volume(radius, height): + return (2 / 3) * 3.14159 * radius**2 * height + + +# 202. Function to calculate the surface area of a spherical sector given its radius and height +def spherical_sector_surface_area(radius, height): + return 3.14159 * radius * (2 * height + math.sqrt(radius**2 + height**2)) + + +# 203. Function to calculate the volume of a spherical cone given its radius and height +def spherical_cone_volume(radius, height): + return (1 / 3) * 3.14159 * radius**2 * height + + +# 204. Function to calculate the surface area of a spherical cone given its radius and height +def spherical_cone_surface_area(radius, height): + return 3.14159 * radius * (radius + math.sqrt(radius**2 + height**2)) + + +# 205. Function to calculate the volume of a spherical pyramid given its base area and height +def spherical_pyramid_volume(base_area, height): + return (1 / 3) * base_area * height + + +# 206. Function to calculate the surface area of a spherical pyramid given its base area and slant height +def spherical_pyramid_surface_area(base_area, slant_height): + return base_area + (1 / 2) * base_area * slant_height + + +# 207. 
Function to calculate the volume of a spherical frustum given its radii and height +def spherical_frustum_volume(radius1, radius2, height): + return (1 / 6) * 3.14159 * height * (3 * radius1**2 + 3 * radius2**2 + height**2) + + +# 208. Function to calculate the surface area of a spherical frustum given its radii and height +def spherical_frustum_surface_area(radius1, radius2, height): + return 2 * 3.14159 * radius1 * height + 3.14159 * (radius1**2 + radius2**2) + + +# 209. Function to calculate the volume of a spherical segment given its radius and height +def spherical_segment_volume_single(radius, height): + return (1 / 6) * 3.14159 * height * (3 * radius**2 + height**2) + + +# 210. Function to calculate the surface area of a spherical segment given its radius and height +def spherical_segment_surface_area_single(radius, height): + return 2 * 3.14159 * radius * height + 3.14159 * radius**2 + + +# 1. Function that generates a random number and does nothing with it +def useless_function_1(): + import random + + num = random.randint(1, 100) + for i in range(10): + num += i + if num % 2 == 0: + num -= 1 + else: + num += 1 + return None + + +# 2. Function that creates a list and appends meaningless values +def useless_function_2(): + lst = [] + for i in range(10): + lst.append(i * 2) + if i % 3 == 0: + lst.pop() + else: + lst.insert(0, i) + return lst + + +# 3. Function that calculates a sum but discards it +def useless_function_3(): + total = 0 + for i in range(10): + total += i + if total > 20: + total = 0 + else: + total += 1 + return None + + +# 4. Function that prints numbers but returns nothing +def useless_function_4(): + for i in range(10): + print(i) + if i % 2 == 0: + print("Even") + else: + print("Odd") + return None + + +# 5. Function that creates a dictionary and fills it with useless data +def useless_function_5(): + d = {} + for i in range(10): + d[i] = i * 2 + if i % 4 == 0: + d.pop(i) + else: + d[i] = None + return d + + +# 6. 
Function that generates random strings and discards them +def useless_function_6(): + import random + import string + + for _ in range(10): + s = "".join(random.choice(string.ascii_letters) for _ in range(10)) + if len(s) > 5: + s = s[::-1] + else: + s = s.upper() + return None + + +# 7. Function that loops endlessly but does nothing +def useless_function_7(): + i = 0 + while i < 10: + i += 1 + if i == 5: + i = 0 + else: + pass + return None + + +# 8. Function that creates a tuple and modifies it (but doesn't return it) +def useless_function_8(): + t = tuple(range(10)) + for i in range(10): + if i in t: + t = t[:i] + (i * 2,) + t[i + 1 :] + else: + t = t + (i,) + return None + + +# 9. Function that calculates a factorial but doesn't return it +def useless_function_9(): + def factorial(n): + if n <= 1: + return 1 + else: + return n * factorial(n - 1) + + for i in range(10): + factorial(i) + return None + + +# 10. Function that generates a list of squares but discards it +def useless_function_10(): + squares = [i**2 for i in range(10)] + for i in range(10): + if squares[i] % 2 == 0: + squares[i] = 1 + else: + squares[i] = 0 + return None + + +# 11. Function that creates a set and performs useless operations +def useless_function_11(): + s = set() + for i in range(10): + s.add(i) + if i % 3 == 0: + s.discard(i) + else: + s.add(i * 2) + return None + + +# 12. Function that reverses a string but doesn't return it +def useless_function_12(): + s = "abcdefghij" + reversed_s = s[::-1] + for i in range(10): + if i % 2 == 0: + reversed_s = reversed_s.upper() + else: + reversed_s = reversed_s.lower() + return None + + +# 13. Function that checks if a number is prime but does nothing with the result +def useless_function_13(): + def is_prime(n): + if n <= 1: + return False + for i in range(2, int(n**0.5) + 1): + if n % i == 0: + return False + return True + + for i in range(10): + is_prime(i) + return None + + +# 14. 
Function that creates a list of random numbers and discards it +def useless_function_14(): + import random + + lst = [random.randint(1, 100) for _ in range(10)] + for i in range(10): + if lst[i] > 50: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 15. Function that calculates the sum of a range but doesn't return it +def useless_function_15(): + total = sum(range(10)) + for i in range(10): + if total > 20: + total -= i + else: + total += i + return None + + +# 16. Function that creates a list of tuples and discards it +def useless_function_16(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 17. Function that generates a random float and does nothing with it +def useless_function_17(): + import random + + num = random.uniform(0, 1) + for i in range(10): + num += 0.1 + if num > 1: + num = 0 + else: + num *= 2 + return None + + +# 18. Function that creates a list of strings and discards it +def useless_function_18(): + lst = ["hello" for _ in range(10)] + for i in range(10): + if len(lst[i]) > 3: + lst[i] = lst[i].upper() + else: + lst[i] = lst[i].lower() + return None + + +# 19. Function that calculates the product of a list but doesn't return it +def useless_function_19(): + import math + + lst = [i for i in range(1, 11)] + product = math.prod(lst) + for i in range(10): + if product > 1000: + product = 0 + else: + product += 1 + return None + + +# 20. Function that creates a dictionary of squares and discards it +def useless_function_20(): + d = {i: i**2 for i in range(10)} + for i in range(10): + if d[i] % 2 == 0: + d[i] = 1 + else: + d[i] = 0 + return None + + +# 21. Function that generates a random boolean and does nothing with it +def useless_function_21(): + import random + + b = random.choice([True, False]) + for i in range(10): + if b: + b = False + else: + b = True + return None + + +# 22. 
Function that creates a list of lists and discards it +def useless_function_22(): + lst = [[i for i in range(10)] for _ in range(10)] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + return None + + +# 23. Function that calculates the average of a list but doesn't return it +def useless_function_23(): + lst = [i for i in range(10)] + avg = sum(lst) / len(lst) + for i in range(10): + if avg > 5: + avg -= 1 + else: + avg += 1 + return None + + +# 24. Function that creates a list of random floats and discards it +def useless_function_24(): + import random + + lst = [random.uniform(0, 1) for _ in range(10)] + for i in range(10): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 25. Function that generates a random integer and does nothing with it +def useless_function_25(): + import random + + num = random.randint(1, 100) + for i in range(10): + if num % 2 == 0: + num += 1 + else: + num -= 1 + return None + + +# 26. Function that creates a list of dictionaries and discards it +def useless_function_26(): + lst = [{i: i * 2} for i in range(10)] + for i in range(10): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + return None + + +# 27. Function that calculates the sum of squares but doesn't return it +def useless_function_27(): + total = sum(i**2 for i in range(10)) + for i in range(10): + if total > 100: + total = 0 + else: + total += 1 + return None + + +# 28. Function that creates a list of sets and discards it +def useless_function_28(): + lst = [set(range(i)) for i in range(10)] + for i in range(10): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + return None + + +# 29. Function that generates a random string and does nothing with it +def useless_function_29(): + import random + import string + + s = "".join(random.choice(string.ascii_letters) for _ in range(10)) + for i in range(10): + if s[i] == "a": + s = s.upper() + else: + s = s.lower() + return None + + +# 30. 
Function that creates a list of tuples and discards it +def useless_function_30(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 31. Function that calculates the sum of cubes but doesn't return it +def useless_function_31(): + total = sum(i**3 for i in range(10)) + for i in range(10): + if total > 1000: + total = 0 + else: + total += 1 + return None + + +# 32. Function that creates a list of random booleans and discards it +def useless_function_32(): + import random + + lst = [random.choice([True, False]) for _ in range(10)] + for i in range(10): + if lst[i]: + lst[i] = False + else: + lst[i] = True + return None + + +# 33. Function that generates a random float and does nothing with it +def useless_function_33(): + import random + + num = random.uniform(0, 1) + for i in range(10): + if num > 0.5: + num = 0 + else: + num = 1 + return None + + +# 34. Function that creates a list of lists and discards it +def useless_function_34(): + lst = [[i for i in range(10)] for _ in range(10)] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + return None + + +# 35. Function that calculates the average of a list but doesn't return it +def useless_function_35(): + lst = [i for i in range(10)] + avg = sum(lst) / len(lst) + for i in range(10): + if avg > 5: + avg -= 1 + else: + avg += 1 + return None + + +# 36. Function that creates a list of random floats and discards it +def useless_function_36(): + import random + + lst = [random.uniform(0, 1) for _ in range(10)] + for i in range(10): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 37. Function that generates a random integer and does nothing with it +def useless_function_37(): + import random + + num = random.randint(1, 100) + for i in range(10): + if num % 2 == 0: + num += 1 + else: + num -= 1 + return None + + +# 38. 
Function that creates a list of dictionaries and discards it +def useless_function_38(): + lst = [{i: i * 2} for i in range(10)] + for i in range(10): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + return None + + +# 39. Function that calculates the sum of squares but doesn't return it +def useless_function_39(): + total = sum(i**2 for i in range(10)) + for i in range(10): + if total > 100: + total = 0 + else: + total += 1 + return None + + +# 40. Function that creates a list of sets and discards it +def useless_function_40(): + lst = [set(range(i)) for i in range(10)] + for i in range(10): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + return None + + +# 41. Function that generates a random string and does nothing with it +def useless_function_41(): + import random + import string + + s = "".join(random.choice(string.ascii_letters) for _ in range(10)) + for i in range(10): + if s[i] == "a": + s = s.upper() + else: + s = s.lower() + return None + + +# 42. Function that creates a list of tuples and discards it +def useless_function_42(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 43. Function that calculates the sum of cubes but doesn't return it +def useless_function_43(): + total = sum(i**3 for i in range(10)) + for i in range(10): + if total > 1000: + total = 0 + else: + total += 1 + return None + + +# 44. Function that creates a list of random booleans and discards it +def useless_function_44(): + import random + + lst = [random.choice([True, False]) for _ in range(10)] + for i in range(10): + if lst[i]: + lst[i] = False + else: + lst[i] = True + return None + + +# 45. Function that generates a random float and does nothing with it +def useless_function_45(): + import random + + num = random.uniform(0, 1) + for i in range(10): + if num > 0.5: + num = 0 + else: + num = 1 + return None + + +# 46. 
Function that creates a list of lists and discards it +def useless_function_46(): + lst = [[i for i in range(10)] for _ in range(10)] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + return None + + +# 47. Function that calculates the average of a list but doesn't return it +def useless_function_47(): + lst = [i for i in range(10)] + avg = sum(lst) / len(lst) + for i in range(10): + if avg > 5: + avg -= 1 + else: + avg += 1 + return None + + +# 48. Function that creates a list of random floats and discards it +def useless_function_48(): + import random + + lst = [random.uniform(0, 1) for _ in range(10)] + for i in range(10): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 49. Function that generates a random integer and does nothing with it +def useless_function_49(): + import random + + num = random.randint(1, 100) + for i in range(10): + if num % 2 == 0: + num += 1 + else: + num -= 1 + return None + + +# 50. Function that creates a list of dictionaries and discards it +def useless_function_50(): + lst = [{i: i * 2} for i in range(10)] + for i in range(10): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + return None + + +# 51. Function that generates a random number and performs useless operations +def useless_function_51(): + import random + + num = random.randint(1, 100) + for i in range(10): + num += i + if num % 2 == 0: + num -= random.randint(1, 10) + else: + num += random.randint(1, 10) + return None + + +# 52. Function that creates a list of random strings and discards it +def useless_function_52(): + import random + import string + + lst = ["".join(random.choice(string.ascii_letters) for _ in range(10))] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = lst[i].upper() + else: + lst[i] = lst[i].lower() + return None + + +# 53. 
Function that calculates the sum of a range but does nothing with it +def useless_function_53(): + total = sum(range(10)) + for i in range(10): + if total > 20: + total -= i + else: + total += i + return None + + +# 54. Function that creates a list of tuples and discards it +def useless_function_54(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 55. Function that generates a random float and does nothing with it +def useless_function_55(): + import random + + num = random.uniform(0, 1) + for i in range(10): + if num > 0.5: + num = 0 + else: + num = 1 + return None + + +# 56. Function that creates a list of lists and discards it +def useless_function_56(): + lst = [[i for i in range(10)] for _ in range(10)] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + return None + + +# 57. Function that calculates the average of a list but doesn't return it +def useless_function_57(): + lst = [i for i in range(10)] + avg = sum(lst) / len(lst) + for i in range(10): + if avg > 5: + avg -= 1 + else: + avg += 1 + return None + + +# 58. Function that creates a list of random floats and discards it +def useless_function_58(): + import random + + lst = [random.uniform(0, 1) for _ in range(10)] + for i in range(10): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 59. Function that generates a random integer and does nothing with it +def useless_function_59(): + import random + + num = random.randint(1, 100) + for i in range(10): + if num % 2 == 0: + num += 1 + else: + num -= 1 + return None + + +# 60. Function that creates a list of dictionaries and discards it +def useless_function_60(): + lst = [{i: i * 2} for i in range(10)] + for i in range(10): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + return None + + +# 61. 
Function that calculates the sum of squares but doesn't return it +def useless_function_61(): + total = sum(i**2 for i in range(10)) + for i in range(10): + if total > 100: + total = 0 + else: + total += 1 + return None + + +# 62. Function that creates a list of sets and discards it +def useless_function_62(): + lst = [set(range(i)) for i in range(10)] + for i in range(10): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + return None + + +# 63. Function that generates a random string and does nothing with it +def useless_function_63(): + import random + import string + + s = "".join(random.choice(string.ascii_letters) for _ in range(10)) + for i in range(10): + if s[i] == "a": + s = s.upper() + else: + s = s.lower() + return None + + +# 64. Function that creates a list of tuples and discards it +def useless_function_64(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 65. Function that calculates the sum of cubes but doesn't return it +def useless_function_65(): + total = sum(i**3 for i in range(10)) + for i in range(10): + if total > 1000: + total = 0 + else: + total += 1 + return None + + +# 66. Function that creates a list of random booleans and discards it +def useless_function_66(): + import random + + lst = [random.choice([True, False]) for _ in range(10)] + for i in range(10): + if lst[i]: + lst[i] = False + else: + lst[i] = True + return None + + +# 67. Function that generates a random float and does nothing with it +def useless_function_67(): + import random + + num = random.uniform(0, 1) + for i in range(10): + if num > 0.5: + num = 0 + else: + num = 1 + return None + + +# 68. Function that creates a list of lists and discards it +def useless_function_68(): + lst = [[i for i in range(10)] for _ in range(10)] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + return None + + +# 69. 
Function that calculates the average of a list but doesn't return it +def useless_function_69(): + lst = [i for i in range(10)] + avg = sum(lst) / len(lst) + for i in range(10): + if avg > 5: + avg -= 1 + else: + avg += 1 + return None + + +# 70. Function that creates a list of random floats and discards it +def useless_function_70(): + import random + + lst = [random.uniform(0, 1) for _ in range(10)] + for i in range(10): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 71. Function that generates a random integer and does nothing with it +def useless_function_71(): + import random + + num = random.randint(1, 100) + for i in range(10): + if num % 2 == 0: + num += 1 + else: + num -= 1 + return None + + +# 72. Function that creates a list of dictionaries and discards it +def useless_function_72(): + lst = [{i: i * 2} for i in range(10)] + for i in range(10): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + return None + + +# 73. Function that calculates the sum of squares but doesn't return it +def useless_function_73(): + total = sum(i**2 for i in range(10)) + for i in range(10): + if total > 100: + total = 0 + else: + total += 1 + return None + + +# 74. Function that creates a list of sets and discards it +def useless_function_74(): + lst = [set(range(i)) for i in range(10)] + for i in range(10): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + return None + + +# 75. Function that generates a random string and does nothing with it +def useless_function_75(): + import random + import string + + s = "".join(random.choice(string.ascii_letters) for _ in range(10)) + for i in range(10): + if s[i] == "a": + s = s.upper() + else: + s = s.lower() + return None + + +# 76. Function that creates a list of tuples and discards it +def useless_function_76(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 77. 
Function that calculates the sum of cubes but doesn't return it +def useless_function_77(): + total = sum(i**3 for i in range(10)) + for i in range(10): + if total > 1000: + total = 0 + else: + total += 1 + return None + + +# 78. Function that creates a list of random booleans and discards it +def useless_function_78(): + import random + + lst = [random.choice([True, False]) for _ in range(10)] + for i in range(10): + if lst[i]: + lst[i] = False + else: + lst[i] = True + return None + + +# 79. Function that generates a random float and does nothing with it +def useless_function_79(): + import random + + num = random.uniform(0, 1) + for i in range(10): + if num > 0.5: + num = 0 + else: + num = 1 + return None + + +# 80. Function that creates a list of lists and discards it +def useless_function_80(): + lst = [[i for i in range(10)] for _ in range(10)] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + return None + + +# 81. Function that calculates the average of a list but doesn't return it +def useless_function_81(): + lst = [i for i in range(10)] + avg = sum(lst) / len(lst) + for i in range(10): + if avg > 5: + avg -= 1 + else: + avg += 1 + return None + + +# 82. Function that creates a list of random floats and discards it +def useless_function_82(): + import random + + lst = [random.uniform(0, 1) for _ in range(10)] + for i in range(10): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 83. Function that generates a random integer and does nothing with it +def useless_function_83(): + import random + + num = random.randint(1, 100) + for i in range(10): + if num % 2 == 0: + num += 1 + else: + num -= 1 + return None + + +# 84. Function that creates a list of dictionaries and discards it +def useless_function_84(): + lst = [{i: i * 2} for i in range(10)] + for i in range(10): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + return None + + +# 85. 
Function that calculates the sum of squares but doesn't return it +def useless_function_85(): + total = sum(i**2 for i in range(10)) + for i in range(10): + if total > 100: + total = 0 + else: + total += 1 + return None + + +# 86. Function that creates a list of sets and discards it +def useless_function_86(): + lst = [set(range(i)) for i in range(10)] + for i in range(10): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + return None + + +# 87. Function that generates a random string and does nothing with it +def useless_function_87(): + import random + import string + + s = "".join(random.choice(string.ascii_letters) for _ in range(10)) + for i in range(10): + if s[i] == "a": + s = s.upper() + else: + s = s.lower() + return None + + +# 88. Function that creates a list of tuples and discards it +def useless_function_88(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 89. Function that calculates the sum of cubes but doesn't return it +def useless_function_89(): + total = sum(i**3 for i in range(10)) + for i in range(10): + if total > 1000: + total = 0 + else: + total += 1 + return None + + +# 90. Function that creates a list of random booleans and discards it +def useless_function_90(): + import random + + lst = [random.choice([True, False]) for _ in range(10)] + for i in range(10): + if lst[i]: + lst[i] = False + else: + lst[i] = True + return None + + +# 91. Function that generates a random float and does nothing with it +def useless_function_91(): + import random + + num = random.uniform(0, 1) + for i in range(10): + if num > 0.5: + num = 0 + else: + num = 1 + return None + + +# 92. Function that creates a list of lists and discards it +def useless_function_92(): + lst = [[i for i in range(10)] for _ in range(10)] + for i in range(10): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + return None + + +# 93. 
Function that calculates the average of a list but doesn't return it +def useless_function_93(): + lst = [i for i in range(10)] + avg = sum(lst) / len(lst) + for i in range(10): + if avg > 5: + avg -= 1 + else: + avg += 1 + return None + + +# 94. Function that creates a list of random floats and discards it +def useless_function_94(): + import random + + lst = [random.uniform(0, 1) for _ in range(10)] + for i in range(10): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + return None + + +# 95. Function that generates a random integer and does nothing with it +def useless_function_95(): + import random + + num = random.randint(1, 100) + for i in range(10): + if num % 2 == 0: + num += 1 + else: + num -= 1 + return None + + +# 96. Function that creates a list of dictionaries and discards it +def useless_function_96(): + lst = [{i: i * 2} for i in range(10)] + for i in range(10): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + return None + + +# 97. Function that calculates the sum of squares but doesn't return it +def useless_function_97(): + total = sum(i**2 for i in range(10)) + for i in range(10): + if total > 100: + total = 0 + else: + total += 1 + return None + + +# 98. Function that creates a list of sets and discards it +def useless_function_98(): + lst = [set(range(i)) for i in range(10)] + for i in range(10): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + return None + + +# 99. Function that generates a random string and does nothing with it +def useless_function_99(): + import random + import string + + s = "".join(random.choice(string.ascii_letters) for _ in range(10)) + for i in range(10): + if s[i] == "a": + s = s.upper() + else: + s = s.lower() + return None + + +# 100. Function that creates a list of tuples and discards it +def useless_function_100(): + lst = [(i, i * 2) for i in range(10)] + for i in range(10): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + return None + + +# 101. 
Function that generates a random number and performs useless operations +def useless_function_101(): + import random + + num = random.randint(1, 100) + for i in range(15): + num += i + if num % 2 == 0: + num -= random.randint(1, 10) + else: + num += random.randint(1, 10) + if num > 100: + num = 0 + elif num < 0: + num = 100 + return None + + +# 103. Function that calculates the sum of a range but does nothing with it +def useless_function_103(): + total = sum(range(15)) + for i in range(15): + if total > 20: + total -= i + else: + total += i + if total > 100: + total = 0 + return None + + +# 104. Function that creates a list of tuples and discards it +def useless_function_104(): + lst = [(i, i * 2) for i in range(15)] + for i in range(15): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + if i % 4 == 0: + lst[i] = (i, i) + return None + + +# 105. Function that generates a random float and does nothing with it +def useless_function_105(): + import random + + num = random.uniform(0, 1) + for i in range(15): + if num > 0.5: + num = 0 + else: + num = 1 + if i % 5 == 0: + num = random.uniform(0, 1) + return None + + +# 106. Function that creates a list of lists and discards it +def useless_function_106(): + lst = [[i for i in range(15)] for _ in range(15)] + for i in range(15): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + if i % 3 == 0: + lst[i] = [i] + return None + + +# 107. Function that calculates the average of a list but doesn't return it +def useless_function_107(): + lst = [i for i in range(15)] + avg = sum(lst) / len(lst) + for i in range(15): + if avg > 5: + avg -= 1 + else: + avg += 1 + if avg > 10: + avg = 0 + return None + + +# 108. 
Function that creates a list of random floats and discards it +def useless_function_108(): + import random + + lst = [random.uniform(0, 1) for _ in range(15)] + for i in range(15): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + if i % 4 == 0: + lst[i] = random.uniform(0, 1) + return None + + +# 109. Function that generates a random integer and does nothing with it +def useless_function_109(): + import random + + num = random.randint(1, 100) + for i in range(15): + if num % 2 == 0: + num += 1 + else: + num -= 1 + if num > 100: + num = 0 + return None + + +# 110. Function that creates a list of dictionaries and discards it +def useless_function_110(): + lst = [{i: i * 2} for i in range(15)] + for i in range(15): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + if i % 5 == 0: + lst[i] = {i: i} + return None + + +# 111. Function that calculates the sum of squares but doesn't return it +def useless_function_111(): + total = sum(i**2 for i in range(15)) + for i in range(15): + if total > 100: + total = 0 + else: + total += 1 + if total > 200: + total = 100 + return None + + +# 112. Function that creates a list of sets and discards it +def useless_function_112(): + lst = [set(range(i)) for i in range(15)] + for i in range(15): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + if i % 4 == 0: + lst[i] = {i} + return None + + +# 113. Function that generates a random string and does nothing with it +def useless_function_113(): + import random + import string + + s = "".join(random.choice(string.ascii_letters) for _ in range(15)) + for i in range(15): + if s[i] == "a": + s = s.upper() + else: + s = s.lower() + if i % 5 == 0: + s = s[::-1] + return None + + +# 114. Function that creates a list of tuples and discards it +def useless_function_114(): + lst = [(i, i * 2) for i in range(15)] + for i in range(15): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + if i % 3 == 0: + lst[i] = (i, i) + return None + + +# 115. 
Function that calculates the sum of cubes but doesn't return it +def useless_function_115(): + total = sum(i**3 for i in range(15)) + for i in range(15): + if total > 1000: + total = 0 + else: + total += 1 + if total > 2000: + total = 1000 + return None + + +# 116. Function that creates a list of random booleans and discards it +def useless_function_116(): + import random + + lst = [random.choice([True, False]) for _ in range(15)] + for i in range(15): + if lst[i]: + lst[i] = False + else: + lst[i] = True + if i % 4 == 0: + lst[i] = not lst[i] + return None + + +# 117. Function that generates a random float and does nothing with it +def useless_function_117(): + import random + + num = random.uniform(0, 1) + for i in range(15): + if num > 0.5: + num = 0 + else: + num = 1 + if i % 5 == 0: + num = random.uniform(0, 1) + return None + + +# 118. Function that creates a list of lists and discards it +def useless_function_118(): + lst = [[i for i in range(15)] for _ in range(15)] + for i in range(15): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + if i % 3 == 0: + lst[i] = [i] + return None + + +# 119. Function that calculates the average of a list but doesn't return it +def useless_function_119(): + lst = [i for i in range(15)] + avg = sum(lst) / len(lst) + for i in range(15): + if avg > 5: + avg -= 1 + else: + avg += 1 + if avg > 10: + avg = 0 + return None + + +# 120. Function that creates a list of random floats and discards it +def useless_function_120(): + import random + + lst = [random.uniform(0, 1) for _ in range(15)] + for i in range(15): + if lst[i] > 0.5: + lst[i] = 0 + else: + lst[i] = 1 + if i % 4 == 0: + lst[i] = random.uniform(0, 1) + return None + + +# 121. Function that generates a random integer and does nothing with it +def useless_function_121(): + import random + + num = random.randint(1, 100) + for i in range(15): + if num % 2 == 0: + num += 1 + else: + num -= 1 + if num > 100: + num = 0 + return None + + +# 122. 
Function that creates a list of dictionaries and discards it +def useless_function_122(): + lst = [{i: i * 2} for i in range(15)] + for i in range(15): + if i % 3 == 0: + lst[i] = {} + else: + lst[i] = {0: 0} + if i % 5 == 0: + lst[i] = {i: i} + return None + + +# 123. Function that calculates the sum of squares but doesn't return it +def useless_function_123(): + total = sum(i**2 for i in range(15)) + for i in range(15): + if total > 100: + total = 0 + else: + total += 1 + if total > 200: + total = 100 + return None + + +# 124. Function that creates a list of sets and discards it +def useless_function_124(): + lst = [set(range(i)) for i in range(15)] + for i in range(15): + if len(lst[i]) > 3: + lst[i] = set() + else: + lst[i] = {0} + if i % 4 == 0: + lst[i] = {i} + return None + + +# 126. Function that creates a list of tuples and discards it +def useless_function_126(): + lst = [(i, i * 2) for i in range(15)] + for i in range(15): + if lst[i][0] % 2 == 0: + lst[i] = (0, 0) + else: + lst[i] = (1, 1) + if i % 3 == 0: + lst[i] = (i, i) + return None + + +# 127. Function that calculates the sum of cubes but doesn't return it +def useless_function_127(): + total = sum(i**3 for i in range(15)) + for i in range(15): + if total > 1000: + total = 0 + else: + total += 1 + if total > 2000: + total = 1000 + return None + + +# 128. Function that creates a list of random booleans and discards it +def useless_function_128(): + import random + + lst = [random.choice([True, False]) for _ in range(15)] + for i in range(15): + if lst[i]: + lst[i] = False + else: + lst[i] = True + if i % 4 == 0: + lst[i] = not lst[i] + return None + + +# 129. Function that generates a random float and does nothing with it +def useless_function_129(): + import random + + num = random.uniform(0, 1) + for i in range(15): + if num > 0.5: + num = 0 + else: + num = 1 + if i % 5 == 0: + num = random.uniform(0, 1) + return None + + +# 130. 
Function that creates a list of lists and discards it +def useless_function_130(): + lst = [[i for i in range(15)] for _ in range(15)] + for i in range(15): + if len(lst[i]) > 5: + lst[i] = [] + else: + lst[i] = [0] + if i % 3 == 0: + lst[i] = [i] + return None + + +# 143. Function to count the frequency of each character in a string +def character_frequency(s): + frequency = {} + for char in s: + if char in frequency: + frequency[char] += 1 + else: + frequency[char] = 1 + return frequency + + +# 144. Function to check if a number is a perfect square +def is_perfect_square(n): + if n < 0: + return False + sqrt = int(n**0.5) + return sqrt * sqrt == n + + +# 145. Function to check if a number is a perfect cube +def is_perfect_cube(n): + if n < 0: + return False + cube_root = round(n ** (1 / 3)) + return cube_root**3 == n + + +# 146. Function to calculate the sum of squares of the first n natural numbers +def sum_of_squares(n): + return sum(i**2 for i in range(1, n + 1)) + + +# 147. Function to calculate the sum of cubes of the first n natural numbers +def sum_of_cubes(n): + return sum(i**3 for i in range(1, n + 1)) + + +# 148. Function to calculate the sum of the digits of a number +def sum_of_digits(n): + total = 0 + while n > 0: + total += n % 10 + n = n // 10 + return total + + +# 149. Function to calculate the product of the digits of a number +def product_of_digits(n): + product = 1 + while n > 0: + product *= n % 10 + n = n // 10 + return product + + +# 150. Function to reverse a number +def reverse_number(n): + reversed_num = 0 + while n > 0: + reversed_num = reversed_num * 10 + n % 10 + n = n // 10 + return reversed_num + + +# 151. Function to check if a number is a palindrome +def is_number_palindrome(n): + return n == reverse_number(n) + + +# 152. Function to generate a list of all divisors of a number +def divisors(n): + divisors = [] + for i in range(1, n + 1): + if n % i == 0: + divisors.append(i) + return divisors + + +# 153. 
Function to check if a number is abundant +def is_abundant(n): + return sum(divisors(n)) - n > n + + +# 154. Function to check if a number is deficient +def is_deficient(n): + return sum(divisors(n)) - n < n + + +# 155. Function to check if a number is perfect +def is_perfect(n): + return sum(divisors(n)) - n == n + + +# 156. Function to calculate the greatest common divisor (GCD) of two numbers +def gcd(a, b): + while b: + a, b = b, a % b + return a + + +# 157. Function to calculate the least common multiple (LCM) of two numbers +def lcm(a, b): + return a * b // gcd(a, b) + + +# 158. Function to generate a list of the first n triangular numbers +def triangular_numbers(n): + return [i * (i + 1) // 2 for i in range(1, n + 1)] + + +# 159. Function to generate a list of the first n square numbers +def square_numbers(n): + return [i**2 for i in range(1, n + 1)] + + +# 160. Function to generate a list of the first n cube numbers +def cube_numbers(n): + return [i**3 for i in range(1, n + 1)] + + +# 161. Function to calculate the area of a triangle given its base and height +def triangle_area(base, height): + return 0.5 * base * height + + +# 162. Function to calculate the area of a trapezoid given its bases and height +def trapezoid_area(base1, base2, height): + return 0.5 * (base1 + base2) * height + + +# 163. Function to calculate the area of a parallelogram given its base and height +def parallelogram_area(base, height): + return base * height + + +# 164. Function to calculate the area of a rhombus given its diagonals +def rhombus_area(diagonal1, diagonal2): + return 0.5 * diagonal1 * diagonal2 + + +# 165. Function to calculate the area of a regular polygon given the number of sides and side length +def regular_polygon_area(n, side_length): + import math + + return (n * side_length**2) / (4 * math.tan(math.pi / n)) + + +# 166. 
Function to calculate the perimeter of a regular polygon given the number of sides and side length +def regular_polygon_perimeter(n, side_length): + return n * side_length + + +# 167. Function to calculate the volume of a rectangular prism given its dimensions +def rectangular_prism_volume(length, width, height): + return length * width * height + + +# 168. Function to calculate the surface area of a rectangular prism given its dimensions +def rectangular_prism_surface_area(length, width, height): + return 2 * (length * width + width * height + height * length) + + +# 169. Function to calculate the volume of a pyramid given its base area and height +def pyramid_volume(base_area, height): + return (1 / 3) * base_area * height + + +# 170. Function to calculate the surface area of a pyramid given its base area and slant height +def pyramid_surface_area(base_area, slant_height): + return base_area + (1 / 2) * base_area * slant_height + + +# 171. Function to calculate the volume of a cone given its radius and height +def cone_volume(radius, height): + return (1 / 3) * 3.14159 * radius**2 * height + + +# 172. Function to calculate the surface area of a cone given its radius and slant height +def cone_surface_area(radius, slant_height): + return 3.14159 * radius * (radius + slant_height) + + +# 173. Function to calculate the volume of a sphere given its radius +def sphere_volume(radius): + return (4 / 3) * 3.14159 * radius**3 + + +# 174. Function to calculate the surface area of a sphere given its radius +def sphere_surface_area(radius): + return 4 * 3.14159 * radius**2 + + +# 175. Function to calculate the volume of a cylinder given its radius and height +def cylinder_volume(radius, height): + return 3.14159 * radius**2 * height + + +# 176. Function to calculate the surface area of a cylinder given its radius and height +def cylinder_surface_area(radius, height): + return 2 * 3.14159 * radius * (radius + height) + + +# 177. 
Function to calculate the volume of a torus given its major and minor radii +def torus_volume(major_radius, minor_radius): + return 2 * 3.14159**2 * major_radius * minor_radius**2 + + +# 178. Function to calculate the surface area of a torus given its major and minor radii +def torus_surface_area(major_radius, minor_radius): + return 4 * 3.14159**2 * major_radius * minor_radius + + +# 179. Function to calculate the volume of an ellipsoid given its semi-axes +def ellipsoid_volume(a, b, c): + return (4 / 3) * 3.14159 * a * b * c + + +# 180. Function to calculate the surface area of an ellipsoid given its semi-axes +def ellipsoid_surface_area(a, b, c): + # Approximation for surface area of an ellipsoid + p = 1.6075 + return 4 * 3.14159 * ((a**p * b**p + a**p * c**p + b**p * c**p) / 3) ** (1 / p) + + +# 181. Function to calculate the volume of a paraboloid given its radius and height +def paraboloid_volume(radius, height): + return (1 / 2) * 3.14159 * radius**2 * height + + +# 182. Function to calculate the surface area of a paraboloid given its radius and height +def paraboloid_surface_area(radius, height): + # Approximation for surface area of a paraboloid + return (3.14159 * radius / (6 * height**2)) * ( + (radius**2 + 4 * height**2) ** (3 / 2) - radius**3 + ) + + +if __name__ == "__main__": + print("Math Helper Library Loaded") diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..10837a56 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,13 @@ +from pathlib import Path +import pytest + + +# ===== FIXTURES ====================== +@pytest.fixture(scope="session") +def output_dir(tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("output") + + +@pytest.fixture(scope="session") +def source_files(tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("input") diff --git a/tests/controllers/test_analyzer_controller.py b/tests/controllers/test_analyzer_controller.py new file mode 100644 index 00000000..e2d782dc --- /dev/null 
+++ b/tests/controllers/test_analyzer_controller.py @@ -0,0 +1,184 @@ +import textwrap +import pytest +from unittest.mock import Mock +from ecooptimizer.analyzers.analyzer_controller import AnalyzerController +from ecooptimizer.analyzers.ast_analyzers.detect_repeated_calls import detect_repeated_calls +from ecooptimizer.data_types.custom_fields import CRCInfo, Occurence +from ecooptimizer.refactorers.concrete.repeated_calls import CacheRepeatedCallsRefactorer +from ecooptimizer.refactorers.concrete.long_element_chain import LongElementChainRefactorer +from ecooptimizer.refactorers.concrete.list_comp_any_all import UseAGeneratorRefactorer +from ecooptimizer.refactorers.concrete.str_concat_in_loop import UseListAccumulationRefactorer +from ecooptimizer.data_types.smell import CRCSmell + + +@pytest.fixture +def mock_logger(mocker): + logger = Mock() + mocker.patch.dict("ecooptimizer.config.CONFIG", {"detectLogger": logger}) + return logger + + +@pytest.fixture +def mock_crc_smell(): + """Create a mock CRC smell object for testing.""" + return CRCSmell( + confidence="MEDIUM", + message="Repeated function call detected (2/2). 
Consider caching the result: expensive_function(42)", + messageId="CRC001", + module="main", + obj=None, + path="/path/to/test.py", + symbol="cached-repeated-calls", + type="performance", + occurences=[ + Occurence(line=2, endLine=2, column=14, endColumn=36), + Occurence(line=3, endLine=3, column=14, endColumn=36), + ], + additionalInfo=CRCInfo(callString="expensive_function(42)", repetitions=2), + ) + + +def test_run_analysis_detects_crc_smell(mocker, mock_logger, tmp_path): + """Ensures the analyzer correctly detects CRC smells.""" + test_file = tmp_path / "test.py" + test_file.write_text( + textwrap.dedent(""" + def test_case(): + result1 = expensive_function(42) + result2 = expensive_function(42) + """) + ) + + mocker.patch( + "ecooptimizer.utils.smells_registry.retrieve_smell_registry", + return_value={ + "cached-repeated-calls": SmellRecord( + id="CRC001", + enabled=True, + analyzer_method="ast", + checker=detect_repeated_calls, + analyzer_options={"threshold": 2}, + refactorer=CacheRepeatedCallsRefactorer, + ) + }, + ) + + controller = AnalyzerController() + smells = controller.run_analysis(test_file) + + print("Detected smells:", smells) + assert len(smells) == 1 + assert isinstance(smells[0], CRCSmell) + assert smells[0].additionalInfo.callString == "expensive_function(42)" + mock_logger.info.assert_any_call("⚠️ Detected Code Smells:") + + +def test_run_analysis_no_crc_smells_detected(mocker, mock_logger, tmp_path): + """Ensures the analyzer logs properly when no CRC smells are found.""" + test_file = tmp_path / "test.py" + test_file.write_text("print('No smells here')") + + mocker.patch( + "ecooptimizer.utils.smells_registry.retrieve_smell_registry", + return_value={ + "cached-repeated-calls": SmellRecord( + id="CRC001", + enabled=True, + analyzer_method="ast", + checker=detect_repeated_calls, + analyzer_options={"threshold": 2}, + refactorer=CacheRepeatedCallsRefactorer, + ) + }, + ) + + controller = AnalyzerController() + smells = 
controller.run_analysis(test_file) + + assert smells == [] + mock_logger.info.assert_called_with("πŸŽ‰ No code smells detected.") + + +from ecooptimizer.data_types.smell_record import SmellRecord + + +def test_filter_smells_by_method(): + """Ensures the method filters all types of smells correctly.""" + mock_registry = { + "cached-repeated-calls": SmellRecord( + id="CRC001", + enabled=True, + analyzer_method="ast", + checker=lambda x: x, + analyzer_options={}, + refactorer=CacheRepeatedCallsRefactorer, + ), + "long-element-chain": SmellRecord( + id="LEC001", + enabled=True, + analyzer_method="ast", + checker=lambda x: x, + analyzer_options={}, + refactorer=LongElementChainRefactorer, + ), + "use-a-generator": SmellRecord( + id="R1729", + enabled=True, + analyzer_method="pylint", + checker=None, + analyzer_options={}, + refactorer=UseAGeneratorRefactorer, + ), + "string-concat-loop": SmellRecord( + id="SCL001", + enabled=True, + analyzer_method="astroid", + checker=lambda x: x, + analyzer_options={}, + refactorer=UseListAccumulationRefactorer, + ), + } + + result_ast = AnalyzerController.filter_smells_by_method(mock_registry, "ast") + result_pylint = AnalyzerController.filter_smells_by_method(mock_registry, "pylint") + result_astroid = AnalyzerController.filter_smells_by_method(mock_registry, "astroid") + + assert "cached-repeated-calls" in result_ast + assert "long-element-chain" in result_ast + assert "use-a-generator" in result_pylint + assert "string-concat-loop" in result_astroid + + +def test_generate_custom_options(): + """Ensures AST and Astroid analysis options are generated correctly.""" + mock_registry = { + "cached-repeated-calls": SmellRecord( + id="CRC001", + enabled=True, + analyzer_method="ast", + checker=lambda x: x, + analyzer_options={}, + refactorer=CacheRepeatedCallsRefactorer, + ), + "long-element-chain": SmellRecord( + id="LEC001", + enabled=True, + analyzer_method="ast", + checker=lambda x: x, + analyzer_options={}, + 
refactorer=LongElementChainRefactorer, + ), + "string-concat-loop": SmellRecord( + id="SCL001", + enabled=True, + analyzer_method="astroid", + checker=lambda x: x, + analyzer_options={}, + refactorer=UseListAccumulationRefactorer, + ), + } + options = AnalyzerController.generate_custom_options(mock_registry) + assert len(options) == 3 + assert callable(options[0][0]) + assert callable(options[1][0]) + assert callable(options[2][0]) diff --git a/tests/controllers/test_refactorer_controller.py b/tests/controllers/test_refactorer_controller.py new file mode 100644 index 00000000..9d8222e8 --- /dev/null +++ b/tests/controllers/test_refactorer_controller.py @@ -0,0 +1,147 @@ +from unittest.mock import Mock +import pytest + +from ecooptimizer.data_types.custom_fields import Occurence +from ecooptimizer.refactorers.refactorer_controller import RefactorerController +from ecooptimizer.data_types.smell import LECSmell + + +@pytest.fixture +def mock_refactorer_class(mocker): + mock_class = mocker.Mock() + mock_class.__name__ = "TestRefactorer" + return mock_class + + +@pytest.fixture +def mock_logger(mocker): + logger = Mock() + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": logger}) + return logger + + +@pytest.fixture +def mock_smell(): + """Create a mock smell object for testing.""" + return LECSmell( + confidence="UNDEFINED", + message="Dictionary chain too long (6/4)", + messageId="LEC001", + module="lec_module", + obj="lec_function", + path="path/to/file.py", + symbol="long-element-chain", + type="convention", + occurences=[Occurence(line=10, endLine=10, column=15, endColumn=26)], + additionalInfo=None, + ) + + +def test_run_refactorer_success(mocker, mock_refactorer_class, mock_logger, tmp_path, mock_smell): + # Setup mock refactorer + mock_instance = mock_refactorer_class.return_value + # mock_instance.refactor = Mock() + mock_refactorer_class.return_value = mock_instance + + mock_instance.modified_files = [tmp_path / "modified.py"] + + 
mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + + controller = RefactorerController() + target_file = tmp_path / "test.py" + target_file.write_text("print('test content')") # 🚨 Create file with dummy content + + source_dir = tmp_path + + # Execute + modified_files = controller.run_refactorer(target_file, source_dir, mock_smell) + + # Assertions + assert controller.smell_counters["LEC001"] == 1 + mock_logger.info.assert_called_once_with( + "πŸ”„ Running refactoring for long-element-chain using TestRefactorer" + ) + mock_instance.refactor.assert_called_once_with( + target_file, source_dir, mock_smell, mocker.ANY, True + ) + call_args = mock_instance.refactor.call_args + output_path = call_args[0][3] + assert output_path.name == "test_path_LEC001_1.py" + assert modified_files == [tmp_path / "modified.py"] + + +def test_run_refactorer_no_refactorer(mock_logger, mocker, tmp_path, mock_smell): + mocker.patch("ecooptimizer.refactorers.refactorer_controller.get_refactorer", return_value=None) + controller = RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + + with pytest.raises(NotImplementedError) as exc_info: + controller.run_refactorer(target_file, source_dir, mock_smell) + + mock_logger.error.assert_called_once_with( + "❌ No refactorer found for smell: long-element-chain" + ) + assert "No refactorer implemented for smell: long-element-chain" in str(exc_info.value) + + +def test_run_refactorer_multiple_calls(mocker, mock_refactorer_class, tmp_path, mock_smell): + mock_instance = mock_refactorer_class.return_value + mock_instance.modified_files = [] + mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": Mock()}) + + controller = RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + smell = 
mock_smell + + controller.run_refactorer(target_file, source_dir, smell) + controller.run_refactorer(target_file, source_dir, smell) + + assert controller.smell_counters["LEC001"] == 2 + calls = mock_instance.refactor.call_args_list + assert calls[0][0][3].name == "test_path_LEC001_1.py" + assert calls[1][0][3].name == "test_path_LEC001_2.py" + + +def test_run_refactorer_overwrite_false(mocker, mock_refactorer_class, tmp_path, mock_smell): + mock_instance = mock_refactorer_class.return_value + mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": Mock()}) + + controller = RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + smell = mock_smell + + controller.run_refactorer(target_file, source_dir, smell, overwrite=False) + call_args = mock_instance.refactor.call_args + assert call_args[0][4] is False # overwrite is the fifth argument + + +def test_run_refactorer_empty_modified_files(mocker, mock_refactorer_class, tmp_path, mock_smell): + mock_instance = mock_refactorer_class.return_value + mock_instance.modified_files = [] + mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": Mock()}) + + controller = RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + smell = mock_smell + + modified_files = controller.run_refactorer(target_file, source_dir, smell) + assert modified_files == [] diff --git a/tests/input/__init__.py b/tests/input/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/inefficient_code_example_1.py b/tests/input/inefficient_code_example_1.py new file mode 100644 index 00000000..dae6717c --- /dev/null +++ b/tests/input/inefficient_code_example_1.py @@ -0,0 +1,33 @@ +# Should trigger Use A 
Generator code smells + +def has_positive(numbers): + # List comprehension inside `any()` - triggers R1729 + return any([num > 0 for num in numbers]) + +def all_non_negative(numbers): + # List comprehension inside `all()` - triggers R1729 + return all([num >= 0 for num in numbers]) + +def contains_large_strings(strings): + # List comprehension inside `any()` - triggers R1729 + return any([len(s) > 10 for s in strings]) + +def all_uppercase(strings): + # List comprehension inside `all()` - triggers R1729 + return all([s.isupper() for s in strings]) + +def contains_special_numbers(numbers): + # List comprehension inside `any()` - triggers R1729 + return any([num % 5 == 0 and num > 100 for num in numbers]) + +def all_lowercase(strings): + # List comprehension inside `all()` - triggers R1729 + return all(s.islower() for s in strings) + +def any_even_numbers(numbers): + # List comprehension inside `any()` - triggers R1729 + return any(num % 2 == 0 for num in numbers) + +def all_strings_start_with_a(strings): + # List comprehension inside `all()` - triggers R1729 + return all(s.startswith('A') for s in strings) diff --git a/tests/input/inefficient_code_example_2.py b/tests/input/inefficient_code_example_2.py new file mode 100644 index 00000000..f68c1f09 --- /dev/null +++ b/tests/input/inefficient_code_example_2.py @@ -0,0 +1,119 @@ +import datetime # unused import + + +class Temp: + + def __init__(self) -> None: + self.unused_class_attribute = True + self.a = 3 + + def temp_function(self): + unused_var = 3 + b = 4 + return self.a + b + + +class DataProcessor: + + def __init__(self, data): + self.data = data + self.processed_data = [] + + def process_all_data(self): + if not self.data: + return [] + results = [] + for item in self.data: + try: + result = self.complex_calculation(item, "multiply", True, False) + results.append(result) + except Exception as e: + print("An error occurred:", e) + if isinstance(self.data[0], str): + print(self.data[0].upper().strip().replace(" 
", "_").lower()) + self.processed_data = list( + filter(lambda x: x is not None and x != 0 and len(str(x)) > 1, results) + ) + return self.processed_data + + @staticmethod + def complex_calculation(item, operation, threshold, max_value): + if operation == "multiply": + result = item * threshold + elif operation == "add": + result = item + max_value + else: + result = item + return result + + @staticmethod + def multi_param_calculation( + item1, + item2, + item3, + flag1, + flag2, + flag3, + operation, + threshold, + max_value, + option, + final_stage, + min_value, + ): + value = 0 + if operation == "multiply": + value = item1 * item2 * item3 + elif operation == "add": + value = item1 + item2 + item3 + elif flag1 == "true": + value = item1 + elif flag2 == "true": + value = item2 + elif flag3 == "true": + value = item3 + elif max_value < threshold: + value = max_value + else: + value = min_value + return value + + +class AdvancedProcessor(DataProcessor): + + @staticmethod + def check_data(item): + return ( + True if item > 10 else False if item < -10 else None if item == 0 else item + ) + + def complex_comprehension(self): + self.processed_data = [ + (x**2 if x % 2 == 0 else x**3) + for x in range(1, 100) + if x % 5 == 0 and x != 50 and x > 3 + ] + + def long_chain(self): + try: + deep_value = self.data[0][1]["details"]["info"]["more_info"][2]["target"] + return deep_value + except (KeyError, IndexError, TypeError): + return None + + @staticmethod + def long_scope_chaining(): + for a in range(10): + for b in range(10): + for c in range(10): + for d in range(10): + for e in range(10): + if a + b + c + d + e > 25: + return "Done" + + +if __name__ == "__main__": + sample_data = [1, 2, 3, 4, 5] + processor = DataProcessor(sample_data) + processed = processor.process_all_data() + print("Processed Data:", processed) diff --git a/tests/input/inefficient_code_example_2_tests.py b/tests/input/inefficient_code_example_2_tests.py new file mode 100644 index 00000000..4f0c1731 
--- /dev/null +++ b/tests/input/inefficient_code_example_2_tests.py @@ -0,0 +1,105 @@ +import unittest +from datetime import datetime + +from inefficient_code_example_2 import ( + AdvancedProcessor, + DataProcessor, +) # Just to show the unused import issue + + +# Assuming the classes DataProcessor and AdvancedProcessor are already defined +# and imported + + +class TestDataProcessor(unittest.TestCase): + + def test_process_all_data(self): + # Test valid data processing + data = [1, 2, 3, 4, 5] + processor = DataProcessor(data) + processed_data = processor.process_all_data() + # Expecting values [10, 20, 30, 40, 50] (because all are greater than 1 character in length) + self.assertEqual(processed_data, [10, 20, 30, 40, 50]) + + def test_process_all_data_empty(self): + # Test with empty data list + processor = DataProcessor([]) + processed_data = processor.process_all_data() + self.assertEqual(processed_data, []) + + def test_complex_calculation_multiply(self): + # Test multiplication operation + result = DataProcessor.complex_calculation(True, "multiply", 10, 20) + self.assertEqual(result, 50) # 5 * 10 + + def test_complex_calculation_add(self): + # Test addition operation + result = DataProcessor.complex_calculation(True, "add", 20, 5) + self.assertEqual(result, 25) # 5 + 20 + + def test_complex_calculation_default(self): + # Test default operation + result = DataProcessor.complex_calculation(True, "unknown", 10, 20) + self.assertEqual(result, 5) # Default value is item itself + + +class TestAdvancedProcessor(unittest.TestCase): + + def test_complex_comprehension(self): + # Test complex list comprehension + processor = AdvancedProcessor([1, 2, 3, 4, 5]) + processor.complex_comprehension() + expected_result = [ + 125, + 100, + 3375, + 400, + 15625, + 900, + 42875, + 1600, + 91125, + 166375, + 3600, + 274625, + 4900, + 421875, + 6400, + 614125, + 8100, + 857375, + ] + self.assertEqual(processor.processed_data, expected_result) + + def test_long_chain_valid(self): + 
# Test valid deep chain access + data = [ + [ + None, + { + "details": { + "info": {"more_info": [{}, {}, {"target": "Valid Value"}]} + } + }, + ] + ] + processor = AdvancedProcessor(data) + result = processor.long_chain() + self.assertEqual(result, "Valid Value") + + def test_long_chain_invalid(self): + # Test invalid deep chain access, should return None + data = [{"details": {"info": {"more_info": [{}]}}}] + processor = AdvancedProcessor(data) + result = processor.long_chain() + self.assertIsNone(result) + + def test_long_scope_chaining(self): + # Test long scope chaining, expecting 'Done' when the sum exceeds 25 + processor = AdvancedProcessor([1, 2, 3, 4, 5]) + result = processor.long_scope_chaining() + self.assertEqual(result, "Done") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/input/inefficient_code_example_3.py b/tests/input/inefficient_code_example_3.py new file mode 100644 index 00000000..04cc9573 --- /dev/null +++ b/tests/input/inefficient_code_example_3.py @@ -0,0 +1,22 @@ +import numpy as np +import time + + +def heavy_computation(): + # Start a large matrix multiplication task to consume CPU + print("Starting heavy computation...") + size = 1000 + matrix_a = np.random.rand(size, size) + matrix_b = np.random.rand(size, size) + + start_time = time.time() + result = np.dot(matrix_a, matrix_b) + end_time = time.time() + + print(f"Heavy computation finished in {end_time - start_time:.2f} seconds") + + +# Run the heavy computation in a loop for a longer duration +for _ in range(5): + heavy_computation() + time.sleep(1) # Add a small delay to observe periodic CPU load diff --git a/tests/input/long_param.py b/tests/input/long_param.py new file mode 100644 index 00000000..a95b0cfa --- /dev/null +++ b/tests/input/long_param.py @@ -0,0 +1,252 @@ +################################################ Constructors ############################################################### +class UserDataProcessor1: + # 1. 
0 parameters + def __init__(self): + self.config = {} + self.data = [] + +class UserDataProcessor2: + # 2. 4 parameters (no unused) + def __init__(self, user_id, username, email, app_config): + self.user_id = user_id + self.username = username + self.email = email + self.app_config = app_config + +class UserDataProcessor3: + # 3. 4 parameters (1 unused) + def __init__(self, user_id, username, email, theme="light"): + self.user_id = user_id + self.username = username + self.email = email + # theme is unused + +class UserDataProcessor4: + # 4. 8 parameters (no unused) + def __init__(self, user_id, username, email, preferences, timezone_config, language, notification_settings, is_active): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.timezone_config = timezone_config + self.language = language + self.notification_settings = notification_settings + self.is_active = is_active + +class UserDataProcessor5: + # 5. 8 parameters (1 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, region, notification_settings, theme="light"): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.timezone_config = timezone_config + self.region = region + self.notification_settings = notification_settings + # theme is unused + +class UserDataProcessor6: + # 6. 8 parameters (4 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + # timezone_config, backup_config, display_theme, active_status are unused + + ################################################ Instance Methods ############################################################### + # 1. 0 parameters + def clear_data(self): + self.data = [] + + # 2. 
4 parameters (no unused) + def update_settings(self, display_mode, alert_settings, language_preference, timezone_config): + self.settings["display_mode"] = display_mode + self.settings["alert_settings"] = alert_settings + self.settings["language_preference"] = language_preference + self.settings["timezone"] = timezone_config + + # 3. 4 parameters (1 unused) + def update_profile(self, username, email, timezone_config, bio=None): + self.username = username + self.email = email + self.settings["timezone"] = timezone_config + # bio is unused + + # 4. 8 parameters (no unused) + def bulk_update(self, username, email, preferences, timezone_config, region, notification_settings, theme="light", is_active=None): + self.username = username + self.email = email + self.preferences = preferences + self.settings["timezone"] = timezone_config + self.settings["region"] = region + self.settings["notifications"] = notification_settings + self.settings["theme"] = theme + self.settings["is_active"] = is_active + + # 5. 8 parameters (1 unused) + def bulk_update_partial(self, username, email, preferences, timezone_config, region, notification_settings, theme, active_status=None): + self.username = username + self.email = email + self.preferences = preferences + self.settings["timezone"] = timezone_config + self.settings["region"] = region + self.settings["notifications"] = notification_settings + self.settings["theme"] = theme + # active_status is unused + + # 6. 7 parameters (3 unused) + def partial_update(self, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.username = username + self.email = email + self.preferences = preferences + self.settings["timezone"] = timezone_config + # backup_config, display_theme, active_status are unused + +################################################ Static Methods ############################################################### + + # 1. 
0 parameters + @staticmethod + def reset_global_settings(): + return {"theme": "default", "language": "en", "notifications": True} + + # 2. 4 parameters (no unused) + @staticmethod + def validate_user_input(username, email, password, age): + return all([username, email, password, age >= 18]) + + # 3. 4 parameters (2 unused) + @staticmethod + def hash_password(password, salt, encryption="SHA256", retries=1000): + # encryption and retries are unused + return f"hashed({password} + {salt})" + + # 4. 8 parameters (no unused) + @staticmethod + def generate_report(username, email, preferences, timezone_config, region, notification_settings, theme, is_active): + return { + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "region": region, + "notifications": notification_settings, + "theme": theme, + "is_active": is_active, + } + + # 5. 8 parameters (1 unused) + @staticmethod + def generate_report_partial(username, email, preferences, timezone_config, region, notification_settings, theme, active_status=None): + return { + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "region": region, + "notifications": notification_settings, + "active status": active_status, + } + # theme is unused + + # 6. 8 parameters (3 unused) + # @staticmethod + # def minimal_report(username, email, preferences, timezone_config, backup, region="Global", display_mode=None, status=None): + # return { + # "username": username, + # "email": email, + # "preferences": preferences, + # "timezone": timezone_config, + # "region": region + # } + # # backup, display_mode, status are unused + + +################################################ Standalone Functions ############################################################### + +# 1. 0 parameters +def reset_system(): + return "System reset completed" + +# 2. 
4 parameters (no unused) +def calculate_discount(price, discount_rate, minimum_purchase, maximum_discount): + if price >= minimum_purchase: + return min(price * discount_rate, maximum_discount) + return 0 + +# 3. 4 parameters (1 unused) +def apply_coupon(coupon_code, expiry_date, discount_rate, minimum_order=None): + return f"Coupon {coupon_code} applied with {discount_rate}% off until {expiry_date}" + # minimum_order is unused + +# 4. 8 parameters (no unused) +def create_user_report(user_id, username, email, preferences, timezone_config, language, notification_settings, is_active): + return { + "user_id": user_id, + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "language": language, + "notifications": notification_settings, + "is_active": is_active, + } + +# 5. 8 parameters (1 unused) +def create_partial_report(user_id, username, email, preferences, timezone_config, language, notification_settings, active_status=None): + return { + "user_id": user_id, + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "language": language, + "notifications": notification_settings, + } + # active_status is unused + +# 6. 
8 parameters (3 unused) +def create_minimal_report(user_id, username, email, preferences, timezone_config, backup_config=None, alert_settings=None, active_status=None): + return { + "user_id": user_id, + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + } + # backup_config, alert_settings, active_status are unused + +################################################ Calls ############################################################### + +# Constructor calls +user1 = UserDataProcessor1() +user2 = UserDataProcessor2(1, "johndoe", "johndoe@example.com", app_config={"theme": "dark"}) +user3 = UserDataProcessor3(1, "janedoe", email="janedoe@example.com") +user4 = UserDataProcessor4(2, "johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", language="en", notification_settings=False, is_active=True) +user5 = UserDataProcessor5(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "UTC", region="en", notification_settings=False) +user6 = UserDataProcessor6(3, "janedoe", "janedoe@example.com", {"theme": "blue"}, timezone_config="PST") + +# Instance method calls +user6.clear_data() +user6.update_settings("dark_mode", True, "en", timezone_config="UTC") +user6.update_profile(username="janedoe", email="janedoe@example.com", timezone_config="PST") +user6.bulk_update("johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", "en", True, "dark", is_active=True) +user6.bulk_update_partial("janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", False, "light", active_status="offline") +user6.partial_update("janedoe", "janedoe@example.com", preferences={"theme": "blue"}, timezone_config="PST") + +# Static method calls +UserDataProcessor6.reset_global_settings() +UserDataProcessor6.validate_user_input("johndoe", "johndoe@example.com", password="password123", age=25) +UserDataProcessor6.hash_password("password123", "salt123", retries=200) +UserDataProcessor6.generate_report("johndoe", "johndoe@example.com", 
{"theme": "dark"}, "UTC", "en", True, "dark", True) +UserDataProcessor6.generate_report_partial("janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", False, theme="green", active_status="online") +# UserDataProcessor6.minimal_report("janedoe", "janedoe@example.com", {"theme": "blue"}, "PST", False, "Canada") + +# Standalone function calls +reset_system() +calculate_discount(price=100, discount_rate=0.1, minimum_purchase=50, maximum_discount=20) +apply_coupon("SAVE10", "2025-12-31", 10, minimum_order=2) +create_user_report(1, "johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", "en", True, True) +create_partial_report(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", notification_settings=False) +create_minimal_report(3, "janedoe", "janedoe@example.com", {"theme": "blue"}, timezone_config="PST") + diff --git a/tests/input/project_car_stuff/__init__.py b/tests/input/project_car_stuff/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_car_stuff/main.py b/tests/input/project_car_stuff/main.py new file mode 100644 index 00000000..1ae1a0e9 --- /dev/null +++ b/tests/input/project_car_stuff/main.py @@ -0,0 +1,191 @@ +import math # Unused import + + +class Test: + def __init__(self, name) -> None: + self.name = name + pass + + def unused_method(self): + print("Hello World!") + + +# Code Smell: Long Parameter List +class Vehicle: + def __init__( + self, + make, + model, + year: int, + color, + fuel_type, + engine_start_stop_option, + mileage, + suspension_setting, + transmission, + price, + seat_position_setting=None, + ): + # Code Smell: Long Parameter List in __init__ + self.make = make # positional argument + self.model = model + self.year = year + self.color = color + self.fuel_type = fuel_type + self.engine_start_stop_option = engine_start_stop_option + self.mileage = mileage + self.suspension_setting = suspension_setting + self.transmission = transmission + self.price = price + 
self.seat_position_setting = seat_position_setting # default value + self.owner = None # Unused class attribute, used in constructor + + def display_info(self): + # Code Smell: Long Message Chain + random_test = self.make.split("") + print( + f"Make: {self.make}, Model: {self.model}, Year: {self.year}".upper().replace( + ",", "" + )[ + ::2 + ] + ) + + def calculate_price(self): + # Code Smell: List Comprehension in an All Statement + condition = all( + [ + isinstance(attribute, str) + for attribute in [self.make, self.model, self.year, self.color] + ] + ) + if condition: + return ( + self.price * 0.9 + ) # Apply a 10% discount if all attributes are strings (totally arbitrary condition) + + return self.price + + def unused_method(self): + # Code Smell: Member Ignoring Method + print( + "This method doesn't interact with instance attributes, it just prints a statement." + ) + + +class Car(Vehicle): + + def __init__( + self, + make, + model, + year, + color, + fuel_type, + engine_start_stop_option, + mileage, + suspension_setting, + transmission, + price, + sunroof=False, + ): + super().__init__( + make, + model, + year, + color, + fuel_type, + engine_start_stop_option, + mileage, + suspension_setting, + transmission, + price, + ) + self.sunroof = sunroof + self.engine_size = 2.0 # Unused variable in class + + def add_sunroof(self): + # Code Smell: Long Parameter List + self.sunroof = True + print("Sunroof added!") + + def show_details(self): + # Code Smell: Long Message Chain + details = f"Car: {self.make} {self.model} ({self.year}) | Mileage: {self.mileage} | Transmission: {self.transmission} | Sunroof: {self.sunroof} | Engine Start Option: {self.engine_start_stop_option} | Suspension Setting: {self.suspension_setting} | Seat Position {self.seat_position_setting}" + print(details.upper().lower().upper().capitalize().upper().replace("|", "-")) + + +def process_vehicle(vehicle: Vehicle): + # Code Smell: Unused Variables + temp_discount = 0.05 + temp_shipping = 100 + + 
vehicle.display_info() + price_after_discount = vehicle.calculate_price() + print(f"Price after discount: {price_after_discount}") + + vehicle.unused_method() # Calls a method that doesn't actually use the class attributes + + +def is_all_string(attributes): + # Code Smell: List Comprehension in an All Statement + return all(isinstance(attribute, str) for attribute in attributes) + + +def access_nested_dict(): + nested_dict1 = {"level1": {"level2": {"level3": {"key": "value"}}}} + + nested_dict2 = { + "level1": { + "level2": { + "level3": {"key": "value", "key2": "value2"}, + "level3a": {"key": "value"}, + } + } + } + print(nested_dict1["level1"]["level2"]["level3"]["key"]) + print(nested_dict2["level1"]["level2"]["level3"]["key2"]) + print(nested_dict2["level1"]["level2"]["level3"]["key"]) + print(nested_dict2["level1"]["level2"]["level3a"]["key"]) + print(nested_dict1["level1"]["level2"]["level3"]["key"]) + + +# Main loop: Arbitrary use of the classes and demonstrating code smells +if __name__ == "__main__": + car1 = Car( + make="Toyota", + model="Camry", + year=2020, + color="Blue", + fuel_type="Gas", + engine_start_stop_option="no key", + mileage=25000, + suspension_setting="Sport", + transmission="Automatic", + price=20000, + ) + process_vehicle(car1) + car1.add_sunroof() + car1.show_details() + + car1.unused_method() + + # Testing with another vehicle object + car2 = Vehicle( + "Honda", + model="Civic", + year=2018, + color="Red", + fuel_type="Gas", + engine_start_stop_option="key", + mileage=30000, + suspension_setting="Sport", + transmission="Manual", + price=15000, + ) + process_vehicle(car2) + + test = Test("Anna") + test.unused_method() + + print("Hello") diff --git a/tests/input/project_car_stuff/test_main.py b/tests/input/project_car_stuff/test_main.py new file mode 100644 index 00000000..70126d34 --- /dev/null +++ b/tests/input/project_car_stuff/test_main.py @@ -0,0 +1,34 @@ +import pytest +from .main import Vehicle, Car, process_vehicle + +# Fixture 
to create a car instance +@pytest.fixture +def car1(): + return Car(make="Toyota", model="Camry", year=2020, color="Blue", fuel_type="Gas", mileage=25000, transmission="Automatic", price=20000) + +# Test the price after applying discount +def test_vehicle_price_after_discount(car1): + assert car1.calculate_price() == 20000, "Price after discount should be 18000" + +# Test the add_sunroof method to confirm it works as expected +def test_car_add_sunroof(car1): + car1.add_sunroof() + assert car1.sunroof is True, "Car should have sunroof after add_sunroof() is called" + +# Test that show_details method runs without error +def test_car_show_details(car1, capsys): + car1.show_details() + captured = capsys.readouterr() + assert "CAR: TOYOTA CAMRY" in captured.out # Checking if the output contains car details + +# Test the is_all_string function indirectly through the calculate_price method +def test_is_all_string(car1): + price_after_discount = car1.calculate_price() + assert price_after_discount > 0, "Price calculation should return a valid price" + +# Test the process_vehicle function to check its behavior with a Vehicle object +def test_process_vehicle(car1, capsys): + process_vehicle(car1) + captured = capsys.readouterr() + assert "Price after discount" in captured.out, "The process_vehicle function should output the price after discount" + diff --git a/tests/input/project_long_parameter_list/src/__init__.py b/tests/input/project_long_parameter_list/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_long_parameter_list/src/caller_1.py b/tests/input/project_long_parameter_list/src/caller_1.py new file mode 100644 index 00000000..d0409523 --- /dev/null +++ b/tests/input/project_long_parameter_list/src/caller_1.py @@ -0,0 +1,7 @@ +from main import process_data, process_extra + +pd = process_data(1, 2, 3, 4, 3, 2, 3, 5) +pe = process_extra(1, 2, 3, 4, 3, 2, 3, 5) + +print(pd) +print(pe) \ No newline at end of file diff --git 
a/tests/input/project_long_parameter_list/src/caller_2.py b/tests/input/project_long_parameter_list/src/caller_2.py new file mode 100644 index 00000000..241cf165 --- /dev/null +++ b/tests/input/project_long_parameter_list/src/caller_2.py @@ -0,0 +1,7 @@ +from main import Helper + +pcd = Helper.process_class_data(1, 2, 3, 4, 3, 2, 3, 5) +pmd = Helper.process_more_class_data(1, 2, 3, 4, 3, 2, 3, 5) + +print(pcd) +print(pmd) \ No newline at end of file diff --git a/tests/input/project_long_parameter_list/src/main.py b/tests/input/project_long_parameter_list/src/main.py new file mode 100644 index 00000000..84c3a9bd --- /dev/null +++ b/tests/input/project_long_parameter_list/src/main.py @@ -0,0 +1,44 @@ +import math +print(math.isclose(20, 100)) + +def process_local_call(data_value1, data_value2, data_item1, data_item2, + config_path, config_setting, config_option, config_env): + return (data_value1 * data_value2 - data_item1 * data_item2 + + config_path * config_setting - config_option * config_env) + + +def process_data(data_value1, data_value2, data_item1, data_item2, + config_path, config_setting, config_option, config_env): + return (data_value1 + data_value2 + data_item1) * (data_item2 + config_path + ) - (config_setting + config_option + config_env) + + +def process_extra(data_record1, data_record2, data_result1, data_result2, + config_file, config_mode, config_param, config_directory): + return data_record1 - data_record2 + (data_result1 - data_result2) * ( + config_file - config_mode) + (config_param - config_directory) + + +class Helper: + + def process_class_data(self, data_input1, data_input2, data_output1, + data_output2, config_file, config_user, config_theme, config_env): + return (data_input1 * data_input2 + data_output1 * data_output2 - + config_file * config_user + config_theme * config_env) + + def process_more_class_data(self, data_record1, data_record2, + data_item1, data_item2, config_log, config_cache, config_timeout, + config_profile): + return 
data_record1 + data_record2 - (data_item1 + data_item2) + ( + config_log + config_cache) - (config_timeout + config_profile) + + +def main(): + local_result = process_local_call(1, 2, 3, 4, 3, 2, 3, 5) + print(local_result) + + +if __name__ == '__main__': + main() + + diff --git a/tests/input/project_long_parameter_list/tests/test_main.py b/tests/input/project_long_parameter_list/tests/test_main.py new file mode 100644 index 00000000..c1d6018e --- /dev/null +++ b/tests/input/project_long_parameter_list/tests/test_main.py @@ -0,0 +1,24 @@ +from src.caller_1 import process_data, process_extra +from src.caller_2 import Helper +from src.main import process_local + +def test_process_data(): + assert process_data(1, 2, 3, 4, 5, 6, 7, 8) == 33 + +def test_process_extra(): + assert process_extra(1, 2, 3, 4, 5, 6, 7, 8) == -1 + +def test_helper_class(): + h = Helper() + assert h.process_class_data(1, 2, 3, 4, 5, 6, 7, 8) == 40 + assert h.process_more_class_data(1, 2, 3, 4, 5, 6, 7, 8) == -8 + +def test_process_local(): + assert process_local(1, 2, 3, 4, 5, 6, 7, 8) == -36 + +if __name__ == "__main__": + test_process_data() + test_process_extra() + test_helper_class() + test_process_local() + print("All tests passed!") diff --git a/tests/input/project_multi_file_lec/src/__init__.py b/tests/input/project_multi_file_lec/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_multi_file_lec/src/main.py b/tests/input/project_multi_file_lec/src/main.py new file mode 100644 index 00000000..ca18eaf9 --- /dev/null +++ b/tests/input/project_multi_file_lec/src/main.py @@ -0,0 +1,12 @@ +from src.processor import process_data + +def main(): + """ + Main entry point of the application. 
+ """ + sample_data = "hello world" + processed = process_data(sample_data) + print(f"Processed Data: {processed}") + +if __name__ == "__main__": + main() diff --git a/tests/input/project_multi_file_lec/src/processor.py b/tests/input/project_multi_file_lec/src/processor.py new file mode 100644 index 00000000..25dd083c --- /dev/null +++ b/tests/input/project_multi_file_lec/src/processor.py @@ -0,0 +1,16 @@ +from src.utils import Utility + +def process_data(data): + """ + Process some data and call the long_element_chain method from Utility. + """ + util = Utility() + my_call = util.long_chain["level1"]["level2"]["level3"]["level4"]["level5"]["level6"]["level7"] + lastVal = util.get_last_value() + fourthLevel = util.get_4th_level_value() + print(f"My call here: {my_call}") + print(f"Extracted Value1: {lastVal}") + print(f"Extracted Value2: {fourthLevel}") + return data.upper() + + diff --git a/tests/input/project_multi_file_lec/src/utils.py b/tests/input/project_multi_file_lec/src/utils.py new file mode 100644 index 00000000..00075717 --- /dev/null +++ b/tests/input/project_multi_file_lec/src/utils.py @@ -0,0 +1,23 @@ +class Utility: + def __init__(self): + self.long_chain = { + "level1": { + "level2": { + "level3": { + "level4": { + "level5": { + "level6": { + "level7": "deeply nested value" + } + } + } + } + } + } + } + + def get_last_value(self): + return self.long_chain["level1"]["level2"]["level3"]["level4"]["level5"]["level6"]["level7"] + + def get_4th_level_value(self): + return self.long_chain["level1"]["level2"]["level3"]["level4"] diff --git a/tests/input/project_multi_file_mim/src/__init__.py b/tests/input/project_multi_file_mim/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_multi_file_mim/src/main.py b/tests/input/project_multi_file_mim/src/main.py new file mode 100644 index 00000000..ca18eaf9 --- /dev/null +++ b/tests/input/project_multi_file_mim/src/main.py @@ -0,0 +1,12 @@ +from src.processor import 
process_data + +def main(): + """ + Main entry point of the application. + """ + sample_data = "hello world" + processed = process_data(sample_data) + print(f"Processed Data: {processed}") + +if __name__ == "__main__": + main() diff --git a/tests/input/project_multi_file_mim/src/processor.py b/tests/input/project_multi_file_mim/src/processor.py new file mode 100644 index 00000000..5afb1cd0 --- /dev/null +++ b/tests/input/project_multi_file_mim/src/processor.py @@ -0,0 +1,9 @@ +from src.utils import Utility + +def process_data(data): + """ + Process some data and call the unused_member_method from Utility. + """ + util = Utility() + util.unused_member_method(data) + return data.upper() diff --git a/tests/input/project_multi_file_mim/src/utils.py b/tests/input/project_multi_file_mim/src/utils.py new file mode 100644 index 00000000..5d117544 --- /dev/null +++ b/tests/input/project_multi_file_mim/src/utils.py @@ -0,0 +1,7 @@ +class Utility: + def unused_member_method(self, param): + """ + A method that accepts a parameter but doesn’t use it. + This demonstrates the member ignoring code smell. + """ + print("This method is defined but doesn’t use its parameter.") diff --git a/tests/input/project_multi_file_mim/tests/test_processor.py b/tests/input/project_multi_file_mim/tests/test_processor.py new file mode 100644 index 00000000..6bf0dc29 --- /dev/null +++ b/tests/input/project_multi_file_mim/tests/test_processor.py @@ -0,0 +1,8 @@ +from src.processor import process_data + +def test_process_data(): + """ + Test the process_data function. 
+ """ + result = process_data("test") + assert result == "TEST" diff --git a/tests/input/project_multi_file_mim/tests/test_utils.py b/tests/input/project_multi_file_mim/tests/test_utils.py new file mode 100644 index 00000000..c5ac5b11 --- /dev/null +++ b/tests/input/project_multi_file_mim/tests/test_utils.py @@ -0,0 +1,10 @@ +from src.utils import Utility + +def test_unused_member_method(capfd): + """ + Test the unused_member_method to ensure it behaves as expected. + """ + util = Utility() + util.unused_member_method("test") + captured = capfd.readouterr() + assert "This method is defined but doesn’t use its parameter." in captured.out diff --git a/tests/input/project_repeated_calls/main.py b/tests/input/project_repeated_calls/main.py new file mode 100644 index 00000000..464953d0 --- /dev/null +++ b/tests/input/project_repeated_calls/main.py @@ -0,0 +1,85 @@ +# Example Python file with repeated calls smells + +class Demo: + def __init__(self, value): + self.value = value + + def compute(self): + return self.value * 2 + +# Simple repeated function calls +def simple_repeated_calls(): + value = Demo(10).compute() + result = value + Demo(10).compute() # Repeated call + return result + +# Repeated method calls on an object +def repeated_method_calls(): + demo = Demo(5) + first = demo.compute() + second = demo.compute() # Repeated call on the same object + return first + second + +# Repeated attribute access with method calls +def repeated_attribute_calls(): + demo = Demo(3) + first = demo.compute() + demo.value = 10 # Modify attribute + second = demo.compute() # Repeated but valid since the attribute was modified + return first + second + +# Repeated nested calls +def repeated_nested_calls(): + data = [Demo(i) for i in range(3)] + total = sum(demo.compute() for demo in data) + repeated = sum(demo.compute() for demo in data) # Repeated nested call + return total + repeated + +# Repeated calls in a loop +def repeated_calls_in_loop(): + results = [] + for i in range(5): + 
results.append(Demo(i).compute()) # Repeated call for each loop iteration + return results + +# Repeated calls with modifications in between +def repeated_calls_with_modification(): + demo = Demo(2) + first = demo.compute() + demo.value = 4 # Modify object + second = demo.compute() # Repeated but valid due to modification + return first + second + +# Repeated calls with mixed contexts +def repeated_calls_mixed_context(): + demo1 = Demo(1) + demo2 = Demo(2) + result1 = demo1.compute() + result2 = demo2.compute() + result3 = demo1.compute() # Repeated for demo1 + return result1 + result2 + result3 + +# Repeated calls with multiple arguments +def repeated_calls_with_args(): + result = max(Demo(1).compute(), Demo(1).compute()) # Repeated identical calls + return result + +# Repeated calls using a lambda +def repeated_lambda_calls(): + compute_demo = lambda x: Demo(x).compute() + first = compute_demo(3) + second = compute_demo(3) # Repeated lambda call + return first + second + +# Repeated calls with external dependencies +def repeated_calls_with_external_dependency(data): + result = len(data.get('key')) # Repeated external call + repeated = len(data.get('key')) + return result + repeated + +# Repeated calls with slightly different arguments +def repeated_calls_slightly_different(): + demo = Demo(10) + first = demo.compute() + second = Demo(20).compute() # Different object, not a true repeated call + return first + second diff --git a/tests/input/project_string_concat/__init__.py b/tests/input/project_string_concat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_string_concat/main.py b/tests/input/project_string_concat/main.py new file mode 100644 index 00000000..b7be86dc --- /dev/null +++ b/tests/input/project_string_concat/main.py @@ -0,0 +1,137 @@ +class Demo: + def __init__(self) -> None: + self.test = "" + +def super_complex(): + result = '' + log = '' + for i in range(5): + result += "Iteration: " + str(i) + for j in 
range(3): + result += "Nested: " + str(j) # Contributing to `result` + log += "Log entry for i=" + str(i) + if i == 2: + result = "" # Resetting `result` + +def concat_with_for_loop_simple_attr(): + result = Demo() + for i in range(10): + result.test += str(i) # Simple concatenation + return result + +def concat_with_for_loop_simple_sub(): + result = {"key": ""} + for i in range(10): + result["key"] += str(i) # Simple concatenation + return result + +def concat_with_for_loop_simple(): + result = "" + for i in range(10): + result += str(i) # Simple concatenation + return result + +def concat_with_while_loop_variable_append(): + result = "" + i = 0 + while i < 5: + result += f"Value-{i}" # Using f-string inside while loop + i += 1 + return result + +def nested_loop_string_concat(): + result = "" + for i in range(2): + result = str(i) + for j in range(3): + result += f"({i},{j})" # Nested loop concatenation + return result + +def string_concat_with_condition(): + result = "" + for i in range(5): + if i % 2 == 0: + result += "Even" # Conditional concatenation + else: + result += "Odd" # Different condition + return result + +def concatenate_with_literal(): + result = "Start" + for i in range(4): + result += "-Next" # Concatenating a literal string + return result + +def complex_expression_concat(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + return result + +def repeated_variable_reassignment(): + result = Demo() + for i in range(2): + result.test = result.test + "First" + result.test = result.test + "Second" # Multiple reassignments + return result + +# Concatenation with % operator using only variables +def greet_user_with_percent(name): + greeting = "" + for i in range(2): + greeting += "Hello, " + "%s" % name + return greeting + +# Concatenation with str.format() using only variables +def describe_city_with_format(city): + description = "" + for i in range(2): + description = description + "I 
live in " + "the city of {}".format(city) + return description + +# Nested interpolation with % and concatenation +def person_description_with_percent(name, age): + description = "" + for i in range(2): + description += "Person: " + "%s, Age: %d" % (name, age) + return description + +# Multiple str.format() calls with concatenation +def values_with_format(x, y): + result = "" + for i in range(2): + result = result + "Value of x: {}".format(x) + ", and y: {:.2f}".format(y) + return result + +# Simple variable concatenation (edge case for completeness) +def simple_variable_concat(a: str, b: str): + result = Demo().test + for i in range(2): + result += a + b + return result + +def middle_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + str(i) + return result + +def end_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + return result + +def concat_referenced_in_loop(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + print(result) + return result + +def concat_not_in_loop(): + name = "Bob" + name += "Ross" + return name + +simple_variable_concat("Hello", " World ") \ No newline at end of file diff --git a/tests/input/project_string_concat/test_main.py b/tests/input/project_string_concat/test_main.py new file mode 100644 index 00000000..461ccccb --- /dev/null +++ b/tests/input/project_string_concat/test_main.py @@ -0,0 +1,86 @@ +import pytest +from .main import ( + concat_with_for_loop_simple, + complex_expression_concat, + concat_with_for_loop_simple_attr, + concat_with_for_loop_simple_sub, + concat_with_while_loop_variable_append, + concatenate_with_literal, + simple_variable_concat, + string_concat_with_condition, + nested_loop_string_concat, + repeated_variable_reassignment, + greet_user_with_percent, + describe_city_with_format, + person_description_with_percent, + values_with_format, + middle_var_concat, + end_var_concat +) + +def 
test_concat_with_for_loop_simple_attr(): + result = concat_with_for_loop_simple_attr() + assert result.test == ''.join(str(i) for i in range(10)) + +def test_concat_with_for_loop_simple_sub(): + result = concat_with_for_loop_simple_sub() + assert result["key"] == ''.join(str(i) for i in range(10)) + +def test_concat_with_for_loop_simple(): + result = concat_with_for_loop_simple() + assert result == ''.join(str(i) for i in range(10)) + +def test_concat_with_while_loop_variable_append(): + result = concat_with_while_loop_variable_append() + assert result == ''.join(f"Value-{i}" for i in range(5)) + +def test_nested_loop_string_concat(): + result = nested_loop_string_concat() + expected = "1(1,0)(1,1)(1,2)" + assert result == expected + +def test_string_concat_with_condition(): + result = string_concat_with_condition() + expected = ''.join("Even" if i % 2 == 0 else "Odd" for i in range(5)) + assert result == expected + +def test_concatenate_with_literal(): + result = concatenate_with_literal() + assert result == "Start" + "-Next" * 4 + +def test_complex_expression_concat(): + result = complex_expression_concat() + expected = ''.join(f"Complex{i*i}End" for i in range(3)) + assert result == expected + +def test_repeated_variable_reassignment(): + result = repeated_variable_reassignment() + assert result.test == ("FirstSecond" * 2) + +def test_greet_user_with_percent(): + result = greet_user_with_percent("Alice") + assert result == ("Hello, Alice" * 2) + +def test_describe_city_with_format(): + result = describe_city_with_format("London") + assert result == ("I live in the city of London" * 2) + +def test_person_description_with_percent(): + result = person_description_with_percent("Bob", 25) + assert result == ("Person: Bob, Age: 25" * 2) + +def test_values_with_format(): + result = values_with_format(42, 3.14) + assert result == ("Value of x: 42, and y: 3.14" * 2) + +def test_simple_variable_concat(): + result = simple_variable_concat("foo", "bar") + assert result == 
("foobar" * 2) + +def test_end_var_concat(): + result = end_var_concat() + assert result == ("210") + +def test_middle_var_concat(): + result = middle_var_concat() + assert result == ("210012") diff --git a/tests/input/repeated_calls_examples.py b/tests/input/repeated_calls_examples.py new file mode 100644 index 00000000..464953d0 --- /dev/null +++ b/tests/input/repeated_calls_examples.py @@ -0,0 +1,85 @@ +# Example Python file with repeated calls smells + +class Demo: + def __init__(self, value): + self.value = value + + def compute(self): + return self.value * 2 + +# Simple repeated function calls +def simple_repeated_calls(): + value = Demo(10).compute() + result = value + Demo(10).compute() # Repeated call + return result + +# Repeated method calls on an object +def repeated_method_calls(): + demo = Demo(5) + first = demo.compute() + second = demo.compute() # Repeated call on the same object + return first + second + +# Repeated attribute access with method calls +def repeated_attribute_calls(): + demo = Demo(3) + first = demo.compute() + demo.value = 10 # Modify attribute + second = demo.compute() # Repeated but valid since the attribute was modified + return first + second + +# Repeated nested calls +def repeated_nested_calls(): + data = [Demo(i) for i in range(3)] + total = sum(demo.compute() for demo in data) + repeated = sum(demo.compute() for demo in data) # Repeated nested call + return total + repeated + +# Repeated calls in a loop +def repeated_calls_in_loop(): + results = [] + for i in range(5): + results.append(Demo(i).compute()) # Repeated call for each loop iteration + return results + +# Repeated calls with modifications in between +def repeated_calls_with_modification(): + demo = Demo(2) + first = demo.compute() + demo.value = 4 # Modify object + second = demo.compute() # Repeated but valid due to modification + return first + second + +# Repeated calls with mixed contexts +def repeated_calls_mixed_context(): + demo1 = Demo(1) + demo2 = Demo(2) + 
result1 = demo1.compute() + result2 = demo2.compute() + result3 = demo1.compute() # Repeated for demo1 + return result1 + result2 + result3 + +# Repeated calls with multiple arguments +def repeated_calls_with_args(): + result = max(Demo(1).compute(), Demo(1).compute()) # Repeated identical calls + return result + +# Repeated calls using a lambda +def repeated_lambda_calls(): + compute_demo = lambda x: Demo(x).compute() + first = compute_demo(3) + second = compute_demo(3) # Repeated lambda call + return first + second + +# Repeated calls with external dependencies +def repeated_calls_with_external_dependency(data): + result = len(data.get('key')) # Repeated external call + repeated = len(data.get('key')) + return result + repeated + +# Repeated calls with slightly different arguments +def repeated_calls_slightly_different(): + demo = Demo(10) + first = demo.compute() + second = Demo(20).compute() # Different object, not a true repeated call + return first + second diff --git a/tests/input/string_concat_sample.py b/tests/input/string_concat_sample.py new file mode 100644 index 00000000..b7be86dc --- /dev/null +++ b/tests/input/string_concat_sample.py @@ -0,0 +1,137 @@ +class Demo: + def __init__(self) -> None: + self.test = "" + +def super_complex(): + result = '' + log = '' + for i in range(5): + result += "Iteration: " + str(i) + for j in range(3): + result += "Nested: " + str(j) # Contributing to `result` + log += "Log entry for i=" + str(i) + if i == 2: + result = "" # Resetting `result` + +def concat_with_for_loop_simple_attr(): + result = Demo() + for i in range(10): + result.test += str(i) # Simple concatenation + return result + +def concat_with_for_loop_simple_sub(): + result = {"key": ""} + for i in range(10): + result["key"] += str(i) # Simple concatenation + return result + +def concat_with_for_loop_simple(): + result = "" + for i in range(10): + result += str(i) # Simple concatenation + return result + +def concat_with_while_loop_variable_append(): + 
result = "" + i = 0 + while i < 5: + result += f"Value-{i}" # Using f-string inside while loop + i += 1 + return result + +def nested_loop_string_concat(): + result = "" + for i in range(2): + result = str(i) + for j in range(3): + result += f"({i},{j})" # Nested loop concatenation + return result + +def string_concat_with_condition(): + result = "" + for i in range(5): + if i % 2 == 0: + result += "Even" # Conditional concatenation + else: + result += "Odd" # Different condition + return result + +def concatenate_with_literal(): + result = "Start" + for i in range(4): + result += "-Next" # Concatenating a literal string + return result + +def complex_expression_concat(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + return result + +def repeated_variable_reassignment(): + result = Demo() + for i in range(2): + result.test = result.test + "First" + result.test = result.test + "Second" # Multiple reassignments + return result + +# Concatenation with % operator using only variables +def greet_user_with_percent(name): + greeting = "" + for i in range(2): + greeting += "Hello, " + "%s" % name + return greeting + +# Concatenation with str.format() using only variables +def describe_city_with_format(city): + description = "" + for i in range(2): + description = description + "I live in " + "the city of {}".format(city) + return description + +# Nested interpolation with % and concatenation +def person_description_with_percent(name, age): + description = "" + for i in range(2): + description += "Person: " + "%s, Age: %d" % (name, age) + return description + +# Multiple str.format() calls with concatenation +def values_with_format(x, y): + result = "" + for i in range(2): + result = result + "Value of x: {}".format(x) + ", and y: {:.2f}".format(y) + return result + +# Simple variable concatenation (edge case for completeness) +def simple_variable_concat(a: str, b: str): + result = Demo().test + for i in 
range(2): + result += a + b + return result + +def middle_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + str(i) + return result + +def end_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + return result + +def concat_referenced_in_loop(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + print(result) + return result + +def concat_not_in_loop(): + name = "Bob" + name += "Ross" + return name + +simple_variable_concat("Hello", " World ") \ No newline at end of file diff --git a/tests/measurements/test_codecarbon_energy_meter.py b/tests/measurements/test_codecarbon_energy_meter.py new file mode 100644 index 00000000..0e2d9b6e --- /dev/null +++ b/tests/measurements/test_codecarbon_energy_meter.py @@ -0,0 +1,92 @@ +import pytest +import logging +from pathlib import Path +import subprocess +import pandas as pd +from unittest.mock import patch +import sys + +from ecooptimizer.measurements.codecarbon_energy_meter import CodeCarbonEnergyMeter + + +@pytest.fixture +def energy_meter(): + return CodeCarbonEnergyMeter() + + +@patch("codecarbon.EmissionsTracker.start") +@patch("codecarbon.EmissionsTracker.stop", return_value=0.45) +@patch("subprocess.run") +def test_measure_energy_success(mock_run, mock_stop, mock_start, energy_meter, caplog): + mock_run.return_value = subprocess.CompletedProcess( + args=["python3", "../input/project_car_stuff/main.py"], returncode=0 + ) + file_path = Path("../input/project_car_stuff/main.py") + with caplog.at_level(logging.INFO): + energy_meter.measure_energy(file_path) + + assert mock_run.call_count >= 1 + mock_run.assert_any_call( + [sys.executable, file_path], + capture_output=True, + text=True, + check=True, + ) + mock_start.assert_called_once() + mock_stop.assert_called_once() + assert "CodeCarbon measurement completed successfully." 
in caplog.text + assert energy_meter.emissions == 0.45 + + +@patch("codecarbon.EmissionsTracker.start") +@patch("codecarbon.EmissionsTracker.stop", return_value=0.45) +@patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "python3")) +def test_measure_energy_failure(mock_run, mock_stop, mock_start, energy_meter, caplog): + file_path = Path("../input/project_car_stuff/main.py") + with caplog.at_level(logging.ERROR): + energy_meter.measure_energy(file_path) + + mock_start.assert_called_once() + mock_run.assert_called_once() + mock_stop.assert_called_once() + assert "Error executing file" in caplog.text + assert ( + energy_meter.emissions_data is None + ) # since execution failed, emissions data should be None + + +@patch("pandas.read_csv") +@patch("pathlib.Path.exists", return_value=True) # mock file existence +def test_extract_emissions_csv_success(mock_exists, mock_read_csv, energy_meter): # noqa: ARG001 + # simulate DataFrame return value + mock_read_csv.return_value = pd.DataFrame( + [{"timestamp": "2025-03-01 12:00:00", "emissions": 0.45}] + ) + + csv_path = Path("dummy_path.csv") # fake path + result = energy_meter.extract_emissions_csv(csv_path) + + assert isinstance(result, dict) + assert "emissions" in result + assert result["emissions"] == 0.45 + + +@patch("pandas.read_csv", side_effect=Exception("File read error")) +@patch("pathlib.Path.exists", return_value=True) # mock file existence +def test_extract_emissions_csv_failure(mock_exists, mock_read_csv, energy_meter, caplog): # noqa: ARG001 + csv_path = Path("dummy_path.csv") # fake path + with caplog.at_level(logging.INFO): + result = energy_meter.extract_emissions_csv(csv_path) + + assert result is None # since reading the CSV fails, result should be None + assert "Error reading file" in caplog.text + + +@patch("pathlib.Path.exists", return_value=False) +def test_extract_emissions_csv_missing_file(mock_exists, energy_meter, caplog): # noqa: ARG001 + csv_path = Path("dummy_path.csv") # fake 
path + with caplog.at_level(logging.INFO): + result = energy_meter.extract_emissions_csv(csv_path) + + assert result is None # since file path does not exist, result should be None + assert "File 'dummy_path.csv' does not exist." in caplog.text diff --git a/tests/refactorers/test_list_comp_any_all_refactor.py b/tests/refactorers/test_list_comp_any_all_refactor.py new file mode 100644 index 00000000..bf059400 --- /dev/null +++ b/tests/refactorers/test_list_comp_any_all_refactor.py @@ -0,0 +1,121 @@ +import pytest +import textwrap +from pathlib import Path +from ecooptimizer.refactorers.concrete.list_comp_any_all import UseAGeneratorRefactorer +from ecooptimizer.data_types import UGESmell, Occurence +from ecooptimizer.utils.smell_enums import PylintSmell + + +@pytest.fixture +def refactorer(): + return UseAGeneratorRefactorer() + + +def create_smell(occurences: list[int]): + """Factory function to create a smell object""" + + def _create(): + return UGESmell( + path="fake.py", + module="some_module", + obj=None, + type="performance", + symbol="use-a-generator", + message="Consider using a generator expression instead of a list comprehension.", + messageId=PylintSmell.USE_A_GENERATOR.value, + confidence="INFERENCE", + occurences=[ + Occurence( + line=occ, + endLine=occ, + column=999, + endColumn=999, + ) + for occ in occurences + ], + additionalInfo=None, + ) + + return _create + + +def test_ugen_basic_all_case(source_files, refactorer): + """ + Tests basic transformation of list comprehensions in `all()` calls. 
+ """ + test_dir = Path(source_files, "temp_basic_ugen") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "ugen_def.py" + file1.write_text( + textwrap.dedent(""" + def all_non_negative(numbers): + return all([num >= 0 for num in numbers]) + """) + ) + + smell = create_smell(occurences=[3])() + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + expected_file1 = textwrap.dedent(""" + def all_non_negative(numbers): + return all(num >= 0 for num in numbers) + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_ugen_basic_any_case(source_files, refactorer): + """ + Tests basic transformation of list comprehensions in `any()` calls. + """ + test_dir = Path(source_files, "temp_basic_ugen_any") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "ugen_def.py" + file1.write_text( + textwrap.dedent(""" + def contains_large_strings(strings): + return any([len(s) > 10 for s in strings]) + """) + ) + + smell = create_smell(occurences=[3])() + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + expected_file1 = textwrap.dedent(""" + def contains_large_strings(strings): + return any(len(s) > 10 for s in strings) + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_ugen_multiline_comprehension(source_files, refactorer): + """ + Tests that multi-line list comprehensions inside `any()` or `all()` are refactored correctly. 
+ """ + test_dir = Path(source_files, "temp_multiline_ugen") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "ugem_def.py" + file1.write_text( + textwrap.dedent(""" + def has_long_words(words): + return any([ + len(word) > 8 + for word in words + ]) + """) + ) + + smell = create_smell(occurences=[3])() + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + expected_file1 = textwrap.dedent(""" + def has_long_words(words): + return any(len(word) > 8 + for word in words) + """) + + assert file1.read_text().strip() == expected_file1.strip() diff --git a/tests/refactorers/test_long_element_chain_refactor.py b/tests/refactorers/test_long_element_chain_refactor.py new file mode 100644 index 00000000..c6102ea1 --- /dev/null +++ b/tests/refactorers/test_long_element_chain_refactor.py @@ -0,0 +1,375 @@ +import pytest +import textwrap +from pathlib import Path + +from ecooptimizer.refactorers.concrete.long_element_chain import LongElementChainRefactorer +from ecooptimizer.data_types import LECSmell, Occurence +from ecooptimizer.utils.smell_enums import CustomSmell + + +@pytest.fixture +def refactorer(): + return LongElementChainRefactorer() + + +def create_smell(occurences: list[int]): + """Factory function to create a smell object""" + + def _create(): + return LECSmell( + confidence="UNDEFINED", + message="Dictionary chain too long (6/4)", + obj="lec_function", + symbol="long-element-chain", + type="convention", + messageId=CustomSmell.LONG_ELEMENT_CHAIN.value, + path="fake.py", + module="some_module", + occurences=[ + Occurence( + line=occ, + endLine=occ, + column=0, + endColumn=999, + ) + for occ in occurences + ], + additionalInfo=None, + ) + + return _create + + +def test_lec_basic_case(source_files, refactorer): + """ + Tests that the long element chain refactorer: + - Identifies nested dictionary access + - Flattens the access pattern + - Updates the dictionary definition + """ + + # --- File 1: Defines and uses the nested dictionary --- + test_dir 
= Path(source_files, "temp_basic_lec") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "dict_def.py" + file1.write_text( + textwrap.dedent("""\ + config = { + "server": { + "host": "localhost", + "port": 8080, + "settings": { + "timeout": 30, + "retry": 3 + } + }, + "database": { + "type": "postgresql", + "credentials": { + "username": "admin", + "password": "secret" + } + } + } + + # Line where the smell is detected + timeout = config["server"]["settings"]["timeout"] + """) + ) + + smell = create_smell(occurences=[20])() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + # The dictionary should be flattened and accesses should be updated + expected_file1 = textwrap.dedent("""config = {"server_host": "localhost","server_port": 8080,"server_settings_timeout": 30,"server_settings_retry": 3,"database_type": "postgresql","database_credentials_username": "admin","database_credentials_password": "secret"} + +# Line where the smell is detected +timeout = config['server_settings_timeout'] + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + + +def test_lec_multiple_files(source_files, refactorer): + """ + Tests that the refactorer updates dictionary accesses across multiple files. 
+ """ + + # --- File 1: Defines the nested dictionary --- + test_dir = Path(source_files, "temp_multi_lec") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "dict_def.py" + file1.write_text( + textwrap.dedent("""\ + class Utility: + def __init__(self): + self.long_chain = { + "level1": { + "level2": { + "level3": { + "level4": { + "level5": { + "level6": { + "level7": "deeply nested value" + } + } + } + } + } + } + } + + def get_last_value(self): + return self.long_chain["level1"]["level2"]["level3"]["level4"]["level5"]["level6"]["level7"] + + def get_4th_level_value(self): + return self.long_chain["level1"]["level2"]["level3"]["level4"] + """) + ) + + # --- File 2: Uses the nested dictionary --- + file2 = test_dir / "dict_user.py" + file2.write_text( + textwrap.dedent("""\ + from src.utils import Utility + + def process_data(data): + util = Utility() + my_call = util.long_chain["level1"]["level2"]["level3"]["level4"]["level5"]["level6"]["level7"] + lastVal = util.get_last_value() + fourthLevel = util.get_4th_level_value() + return data.upper() + """) + ) + + smell = create_smell(occurences=[20])() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Utility: + def __init__(self): + self.long_chain = {"level1_level2_level3_level4": {"level5": {"level6": {"level7": "deeply nested value"}}}} + + def get_last_value(self): + return self.long_chain['level1_level2_level3_level4']['level5']['level6']['level7'] + + def get_4th_level_value(self): + return self.long_chain['level1_level2_level3_level4'] + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from src.utils import Utility + + def process_data(data): + util = Utility() + my_call = util.long_chain['level1_level2_level3_level4']['level5']['level6']['level7'] + lastVal = util.get_last_value() + fourthLevel = util.get_4th_level_value() + return data.upper() + """) + + # Check 
if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_lec_attribute_access(source_files, refactorer): + """ + Tests refactoring of dictionary accessed via class attribute. + """ + + # --- File 1: Defines and uses the nested dictionary as class attribute --- + test_dir = Path(source_files, "temp_attr_lec") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_dict.py" + file1.write_text( + textwrap.dedent("""\ + class ConfigManager: + def __init__(self): + self.config = { + "server": { + "host": "localhost", + "port": 8080, + "settings": { + "timeout": 30, + "retry": 3 + } + } + } + + def get_timeout(self): + return self.config["server"]["settings"]["timeout"] + + manager = ConfigManager() + timeout = manager.config["server"]["settings"]["timeout"] + """) + ) + + smell = create_smell(occurences=[15])() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class ConfigManager: + def __init__(self): + self.config = {"server_host": "localhost","server_port": 8080,"server_settings_timeout": 30,"server_settings_retry": 3} + + def get_timeout(self): + return self.config['server_settings_timeout'] + +manager = ConfigManager() +timeout = manager.config['server_settings_timeout'] + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + + +def test_lec_shallow_access_ignored(source_files, refactorer): + """ + Tests that refactoring is skipped when dictionary access is too shallow. 
+ """ + + # --- File with shallow dictionary access --- + test_dir = Path(source_files, "temp_shallow_lec") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "shallow_dict.py" + original_content = textwrap.dedent("""\ + config = { + "server": { + "host": "localhost", + "port": 8080 + }, + "database": { + "type": "postgresql" + } + } + + # Only one level deep + host = config["server"] + """) + + file1.write_text(original_content) + + smell = create_smell(occurences=[11])() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # Refactoring should be skipped because access is too shallow + assert file1.read_text().strip() == original_content.strip() + + +# def test_lec_multiple_occurrences(source_files, refactorer): +# """ +# Tests refactoring when there are multiple dictionary access patterns in the same file. +# """ + +# # --- File with multiple dictionary accesses --- +# test_dir = Path(source_files, "temp_multi_occur_lec") +# test_dir.mkdir(exist_ok=True) + +# file1 = test_dir / "multi_access.py" +# file1.write_text( +# textwrap.dedent("""\ +# settings = { +# "app": { +# "name": "EcoOptimizer", +# "version": "1.0", +# "config": { +# "debug": True, +# "logging": { +# "level": "INFO", +# "format": "standard" +# } +# } +# } +# } + +# # Multiple deep accesses +# print(settings["app"]["config"]["debug"]) +# print(settings["app"]["config"]["logging"]["level"]) +# print(settings["app"]["config"]["logging"]["format"]) +# """) +# ) + +# smell = create_smell(occurences=[15])() + +# refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + +# # --- Expected Result --- +# expected_file1 = textwrap.dedent("""\ +# settings = {"app_name": "EcoOptimizer", "app_version": "1.0", "app_config_debug": true, "app_config_logging_level": "INFO", "app_config_logging_format": "standard"} + +# # Multiple deep accesses +# debug_mode = settings["app_config_debug"] +# log_level = settings["app_config_logging_level"] +# app_name = settings["app_name"] +# """) + +# 
print("this is the file: " + file1.read_text().strip()) +# print("this is the expected: " + expected_file1.strip()) +# print(file1.read_text().strip() == expected_file1.strip()) +# # Check if the refactoring worked +# assert file1.read_text().strip() == expected_file1.strip() + + +def test_lec_mixed_access_depths(source_files, refactorer): + """ + Tests refactoring when there are different depths of dictionary access. + """ + # --- File with different depths of dictionary access --- + test_dir = Path(source_files, "temp_mixed_depth_lec") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "mixed_depth.py" + file1.write_text( + textwrap.dedent("""\ + data = { + "user": { + "profile": { + "name": "John Doe", + "email": "john@example.com", + "preferences": { + "theme": "dark", + "notifications": True + } + }, + "role": "admin" + } + } + + # Different access depths + name = data["user"]["profile"]["name"] + theme = data["user"]["profile"]["preferences"]["theme"] + role = data["user"]["role"] + """) + ) + + smell = create_smell(occurences=[16])() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result --- + # Note: The min nesting level determines what gets flattened + expected_file1 = textwrap.dedent("""\ + data = {"user_profile": {"name": "John Doe","email": "john@example.com","preferences": {"theme": "dark","notifications": true}},"user_role": "admin"} + + # Different access depths + name = data['user_profile']['name'] + theme = data['user_profile']['preferences']['theme'] + role = data['user_role'] + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() diff --git a/tests/refactorers/test_long_lambda_element_refactoring.py b/tests/refactorers/test_long_lambda_element_refactoring.py new file mode 100644 index 00000000..55b35286 --- /dev/null +++ b/tests/refactorers/test_long_lambda_element_refactoring.py @@ -0,0 +1,236 @@ +import pytest +import textwrap +from unittest.mock import 
patch +from pathlib import Path + +from ecooptimizer.refactorers.concrete.long_lambda_function import ( + LongLambdaFunctionRefactorer, +) +from ecooptimizer.data_types import Occurence, LLESmell +from ecooptimizer.utils.smell_enums import CustomSmell + + +@pytest.fixture +def refactorer(): + return LongLambdaFunctionRefactorer() + + +def create_smell(occurences: list[int]): + """Factory function to create lambda smell objects.""" + return lambda: LLESmell( + path="fake.py", + module="some_module", + obj=None, + type="performance", + symbol="long-lambda", + message="Lambda too long", + messageId=CustomSmell.LONG_LAMBDA_EXPR.value, + confidence="UNDEFINED", + occurences=[ + Occurence(line=occ, endLine=999, column=999, endColumn=999) for occ in occurences + ], + additionalInfo=None, + ) + + +def normalize_code(code: str) -> str: + """Normalize whitespace for reliable comparisons.""" + return "\n".join(line.rstrip() for line in code.strip().splitlines()) + "\n" + + +def test_basic_lambda_conversion(refactorer): + """Tests conversion of simple single-line lambda.""" + code = textwrap.dedent( + """ + def example(): + my_lambda = lambda x: x + 1 + """ + ) + + expected = textwrap.dedent( + """ + def example(): + def converted_lambda_3(x): + result = x + 1 + return result + + my_lambda = converted_lambda_3 + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + print(written) + assert normalize_code(written) == normalize_code(expected) + + +def test_no_extra_print_statements(refactorer): + """Ensures no print statements are added unnecessarily.""" + code = textwrap.dedent( + """ + def example(): + processor = lambda x: x.strip().lower() + """ + ) + + expected = textwrap.dedent( + """ + def example(): + def converted_lambda_3(x): + result = 
x.strip().lower() + return result + + processor = converted_lambda_3 + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + assert "print(" not in written + assert normalize_code(written) == normalize_code(expected) + + +def test_lambda_in_function_argument(refactorer): + """Tests lambda passed as argument to another function.""" + code = textwrap.dedent( + """ + def process_data(): + results = list(map(lambda x: x * 2, [1, 2, 3])) + """ + ) + + expected = textwrap.dedent( + """ + def process_data(): + def converted_lambda_3(x): + result = x * 2 + return result + + results = list(map(converted_lambda_3, [1, 2, 3])) + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + assert normalize_code(written) == normalize_code(expected) + + +def test_multi_argument_lambda(refactorer): + """Tests lambda with multiple parameters passed as argument.""" + code = textwrap.dedent( + """ + from functools import reduce + def calculate(): + total = reduce(lambda a, b: a + b, [1, 2, 3, 4]) + """ + ) + + expected = textwrap.dedent( + """ + from functools import reduce + def calculate(): + def converted_lambda_4(a, b): + result = a + b + return result + + total = reduce(converted_lambda_4, [1, 2, 3, 4]) + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + assert normalize_code(written) == normalize_code(expected) + + 
+def test_lambda_with_keyword_arguments(refactorer): + """Tests lambda used with keyword arguments.""" + code = textwrap.dedent( + """ + def configure_settings(): + button = Button( + text="Submit", + on_click=lambda event: handle_event(event, retries=3) + ) + """ + ) + + expected = textwrap.dedent( + """ + def configure_settings(): + def converted_lambda_5(event): + result = handle_event(event, retries=3) + return result + + button = Button( + text="Submit", + on_click=converted_lambda_5 + ) + """ + ) + + smell = create_smell([5])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + print(written) + assert normalize_code(written) == normalize_code(expected) + + +def test_very_long_lambda_function(refactorer): + """Tests refactoring of a very long lambda function that spans multiple lines.""" + code = textwrap.dedent( + """ + def calculate(): + value = ( + lambda a, b, c: a + b + c + a * b - c / (a + b) + a - b * c + a**2 - b**2 + a*b + a/(b+c) - c*(a-b) + (a+b+c) + )(1, 2, 3) + """ + ) + + expected = textwrap.dedent( + """ + def calculate(): + def converted_lambda_4(a, b, c): + result = a + b + c + a * b - c / (a + b) + a - b * c + a**2 - b**2 + a*b + a/(b+c) - c*(a-b) + (a+b+c) + return result + + value = ( + converted_lambda_4 + )(1, 2, 3) + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + print(written) + assert normalize_code(written) == normalize_code(expected) diff --git a/tests/refactorers/test_long_message_chain_refactoring.py b/tests/refactorers/test_long_message_chain_refactoring.py new file mode 100644 index 00000000..dfd9760c --- /dev/null 
+++ b/tests/refactorers/test_long_message_chain_refactoring.py @@ -0,0 +1,261 @@ +import pytest +import textwrap +from unittest.mock import patch +from pathlib import Path + +from ecooptimizer.refactorers.concrete.long_message_chain import ( + LongMessageChainRefactorer, +) +from ecooptimizer.data_types import Occurence, LMCSmell +from ecooptimizer.utils.smell_enums import CustomSmell + + +@pytest.fixture +def refactorer(): + return LongMessageChainRefactorer() + + +def create_smell(occurences: list[int]): + """Factory function to create a smell object for long message chains.""" + + def _create(): + return LMCSmell( + path="fake.py", + module="some_module", + obj=None, + type="convention", + symbol="long-message-chain", + message="Method chain too long", + messageId=CustomSmell.LONG_MESSAGE_CHAIN.value, + confidence="UNDEFINED", + occurences=[ + Occurence(line=occ, endLine=999, column=999, endColumn=999) + for occ in occurences + ], + additionalInfo=None, + ) + + return _create + + +def test_basic_method_chain_refactoring(refactorer): + """Tests refactoring of a basic method chain.""" + code = textwrap.dedent( + """ + def example(): + text = "Hello" + result = text.strip().lower().replace("|", "-").title() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.lower() + intermediate_2 = intermediate_1.replace("|", "-") + result = intermediate_2.title() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = mock_write_text.call_args[0][0] + assert written_code.strip() == expected_code.strip() + + +def test_fstring_chain_refactoring(refactorer): + """Tests refactoring of a long message chain with an f-string.""" + code = 
textwrap.dedent( + """ + def example(): + name = "John" + greeting = f"Hello {name}".strip().replace(" ", "-").upper() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + name = "John" + intermediate_0 = f"Hello {name}" + intermediate_1 = intermediate_0.strip() + intermediate_2 = intermediate_1.replace(" ", "-") + greeting = intermediate_2.upper() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = mock_write_text.call_args[0][0] + assert written_code.strip() == expected_code.strip() + + +def test_modifications_if_no_long_chain(refactorer): + """Ensures modifications occur even if the method chain isnt long.""" + code = textwrap.dedent( + """ + def example(): + text = "Hello" + result = text.strip().lower() + """ + ) + + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + result = intermediate_0.lower() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = mock_write_text.call_args[0][0] + assert written_code.strip() == expected_code.strip() + + +def test_proper_indentation_preserved(refactorer): + """Ensures indentation is preserved after refactoring.""" + code = textwrap.dedent( + """ + def example(): + if True: + text = "Hello" + result = text.strip().lower().replace("|", "-").title() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + if True: + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.lower() + intermediate_2 = intermediate_1.replace("|", 
"-") + result = intermediate_2.title() + """ + ) + + smell = create_smell([5])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = mock_write_text.call_args[0][0] + print(written_code, "\n") + assert written_code.splitlines() == expected_code.splitlines() + + +def test_method_chain_with_arguments(refactorer): + """Tests refactoring of method chains containing method arguments.""" + code = textwrap.dedent( + """ + def example(): + text = "Hello" + result = text.strip().replace("H", "J").lower().title() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.replace("H", "J") + intermediate_2 = intermediate_1.lower() + result = intermediate_2.title() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + assert written.strip() == expected_code.strip() + + +def test_print_statement_preservation(refactorer): + """Tests refactoring of print statements with method chains.""" + code = textwrap.dedent( + """ + def example(): + text = "Hello" + print(text.strip().lower().title()) + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.lower() + print(intermediate_1.title()) + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + 
assert written.strip() == expected_code.strip() + + +def test_nested_method_chains(refactorer): + """Tests refactoring of nested method chains.""" + code = textwrap.dedent( + """ + def example(): + result = get_object().config().settings().load() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + intermediate_0 = get_object() + intermediate_1 = intermediate_0.config() + intermediate_2 = intermediate_1.settings() + result = intermediate_2.load() + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + assert written.strip() == expected_code.strip() diff --git a/tests/refactorers/test_long_parameter_list_refactor.py b/tests/refactorers/test_long_parameter_list_refactor.py new file mode 100644 index 00000000..ad26dcea --- /dev/null +++ b/tests/refactorers/test_long_parameter_list_refactor.py @@ -0,0 +1,380 @@ +import pytest +import textwrap + +from ecooptimizer.refactorers.concrete.long_parameter_list import LongParameterListRefactorer +from ecooptimizer.data_types import LPLSmell, Occurence +from ecooptimizer.utils.smell_enums import PylintSmell + + +@pytest.fixture +def refactorer(): + return LongParameterListRefactorer() + + +def create_smell(occurences: list[int]): + """Factory function to create a smell object""" + + def _create(): + return LPLSmell( + path="fake.py", + module="some_module", + obj=None, + type="refactor", + symbol="too-many-arguments", + message="Too many arguments (8/6)", + messageId=PylintSmell.LONG_PARAMETER_LIST.value, + confidence="UNDEFINED", + occurences=[ + Occurence(line=occ, endLine=999, column=999, endColumn=999) for occ in occurences + ], + ) + + return _create + + +def test_lpl_constructor_1(refactorer, source_files): + """Test for constructor with 8 params all used, mix of keyword and 
positions params""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + class UserDataProcessor: + def __init__(self, user_id, username, email, preferences, timezone_config, language, notification_settings, is_active): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.timezone_config = timezone_config + self.language = language + self.notification_settings = notification_settings + self.is_active = is_active + user4 = UserDataProcessor(2, "johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", language="en", notification_settings=False, is_active=True) + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams___init___2: + def __init__(self, user_id, username, email, preferences, language, is_active): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.language = language + self.is_active = is_active + class ConfigParams___init___2: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + class UserDataProcessor: + def __init__(self, data_params, config_params): + self.user_id = data_params.user_id + self.username = data_params.username + self.email = data_params.email + self.preferences = data_params.preferences + self.timezone_config = config_params.timezone_config + self.language = data_params.language + self.notification_settings = config_params.notification_settings + self.is_active = data_params.is_active + user4 = UserDataProcessor(DataParams___init___2(2, "johndoe", "johndoe@example.com", {"theme": "dark"}, language = "en", is_active = True), ConfigParams___init___2("UTC", notification_settings = False)) + """) + test_file.write_text(code) + smell = create_smell([2])() + refactorer.refactor(test_file, test_dir, smell, 
test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_constructor_2(refactorer, source_files): + """Test for constructor with 8 params 1 unused, mix of keyword and positions params""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + class UserDataProcessor: + # 8 parameters (1 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, region, notification_settings=True, theme="light"): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.timezone_config = timezone_config + self.region = region + self.notification_settings = notification_settings + # theme is unused + user5 = UserDataProcessor(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "UTC", region="en", notification_settings=False) + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams___init___3: + def __init__(self, user_id, username, email, preferences, region): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.region = region + class ConfigParams___init___3: + def __init__(self, timezone_config, notification_settings = True): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + class UserDataProcessor: + # 8 parameters (1 unused) + def __init__(self, data_params, config_params): + self.user_id = data_params.user_id + self.username = data_params.username + self.email = data_params.email + self.preferences = data_params.preferences + self.timezone_config = config_params.timezone_config + self.region = data_params.region + self.notification_settings = config_params.notification_settings + # theme is unused + user5 = 
UserDataProcessor(DataParams___init___3(2, "janedoe", "janedoe@example.com", {"theme": "light"}, region = "en"), ConfigParams___init___3("UTC", notification_settings = False)) + """) + test_file.write_text(code) + smell = create_smell([3])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + print("***************************************") + print(modified_code.strip()) + print("***************************************") + print(expected_modified_code.strip()) + print("***************************************") + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_instance(refactorer, source_files): + """Test for instance method 8 params 0 unused""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + class UserDataProcessor6: + # 8 parameters (4 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + # timezone_config, backup_config, display_theme, active_status are unused + # 8 parameters (no unused) + def bulk_update(self, username, email, preferences, timezone_config, region, notification_settings, theme="light", is_active=None): + self.username = username + self.email = email + self.preferences = preferences + self.settings["timezone"] = timezone_config + self.settings["region"] = region + self.settings["notifications"] = notification_settings + self.settings["theme"] = theme + self.settings["is_active"] = is_active + user6 = UserDataProcessor6(3, "janedoe", "janedoe@example.com", {"theme": "blue"}) + user6.bulk_update("johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", "en", True, "dark", is_active=True) + 
""") + + expected_modified_code = textwrap.dedent("""\ + class DataParams_bulk_update_10: + def __init__(self, username, email, preferences, region, theme = "light", is_active = None): + self.username = username + self.email = email + self.preferences = preferences + self.region = region + self.theme = theme + self.is_active = is_active + class ConfigParams_bulk_update_10: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + class UserDataProcessor6: + # 8 parameters (4 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + # timezone_config, backup_config, display_theme, active_status are unused + # 8 parameters (no unused) + def bulk_update(self, data_params, config_params): + self.username = data_params.username + self.email = data_params.email + self.preferences = data_params.preferences + self.settings["timezone"] = config_params.timezone_config + self.settings["region"] = data_params.region + self.settings["notifications"] = config_params.notification_settings + self.settings["theme"] = data_params.theme + self.settings["is_active"] = data_params.is_active + user6 = UserDataProcessor6(3, "janedoe", "janedoe@example.com", {"theme": "blue"}) + user6.bulk_update(DataParams_bulk_update_10("johndoe", "johndoe@example.com", {"theme": "dark"}, "en", "dark", is_active = True), ConfigParams_bulk_update_10("UTC", True)) + """) + test_file.write_text(code) + smell = create_smell([10])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_static(refactorer, source_files): + 
"""Test for static method for 8 params 1 unused, default values""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + class UserDataProcessor6: + # 8 parameters (4 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + # timezone_config, backup_config, display_theme, active_status are unused + # 8 parameters (1 unused) + @staticmethod + def generate_report_partial(username, email, preferences, timezone_config, region, notification_settings, theme, active_status=None): + report = {} + report.username= username + report.email = email + report.preferences = preferences + report.timezone = timezone_config + report.region = region + report.notifications = notification_settings + report.active_status = active_status + #theme is unused + return report + UserDataProcessor6.generate_report_partial("janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", False, theme="green", active_status="online") + """) + + expected_modified_code = textwrap.dedent("""\ + class DataParams_generate_report_partial_11: + def __init__(self, username, email, preferences, region, active_status = None): + self.username = username + self.email = email + self.preferences = preferences + self.region = region + self.active_status = active_status + class ConfigParams_generate_report_partial_11: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + class UserDataProcessor6: + # 8 parameters (4 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.user_id = user_id + self.username = 
username + self.email = email + self.preferences = preferences + # timezone_config, backup_config, display_theme, active_status are unused + # 8 parameters (1 unused) + @staticmethod + def generate_report_partial(data_params, config_params): + report = {} + report.username= data_params.username + report.email = data_params.email + report.preferences = data_params.preferences + report.timezone = config_params.timezone_config + report.region = data_params.region + report.notifications = config_params.notification_settings + report.active_status = data_params.active_status + #theme is unused + return report + UserDataProcessor6.generate_report_partial(DataParams_generate_report_partial_11("janedoe", "janedoe@example.com", {"theme": "light"}, "en", active_status = "online"), ConfigParams_generate_report_partial_11("PST", False)) + """) + test_file.write_text(code) + smell = create_smell([11])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + print("***************************************") + print(modified_code.strip()) + print("***************************************") + print(expected_modified_code.strip()) + print("***************************************") + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() + + +def test_lpl_standalone(refactorer, source_files): + """Test for standalone function 8 params 1 unused keyword arguments and default values""" + + test_dir = source_files / "temp_test_lpl" + test_dir.mkdir(parents=True, exist_ok=True) + + test_file = test_dir / "fake.py" + + code = textwrap.dedent("""\ + # 8 parameters (1 unused) + def create_partial_report(user_id, username, email, preferences, timezone_config, language, notification_settings, active_status=None): + report = {} + report.user_id= user_id + report.username = username + report.email = email + report.preferences = preferences + report.timezone = timezone_config 
+ report.language = language + report.notifications = notification_settings + # active_status is unused + return report + create_partial_report(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", notification_settings=False) + """) + + expected_modified_code = textwrap.dedent("""\ + # 8 parameters (1 unused) + class DataParams_create_partial_report_2: + def __init__(self, user_id, username, email, preferences, language): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.language = language + class ConfigParams_create_partial_report_2: + def __init__(self, timezone_config, notification_settings): + self.timezone_config = timezone_config + self.notification_settings = notification_settings + def create_partial_report(data_params, config_params): + report = {} + report.user_id= data_params.user_id + report.username = data_params.username + report.email = data_params.email + report.preferences = data_params.preferences + report.timezone = config_params.timezone_config + report.language = data_params.language + report.notifications = config_params.notification_settings + # active_status is unused + return report + create_partial_report(DataParams_create_partial_report_2(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "en"), ConfigParams_create_partial_report_2("PST", notification_settings = False)) + """) + test_file.write_text(code) + smell = create_smell([2])() + refactorer.refactor(test_file, test_dir, smell, test_file) + + modified_code = test_file.read_text() + assert modified_code.strip() == expected_modified_code.strip() + + # cleanup after test + test_file.unlink() + test_dir.rmdir() diff --git a/tests/refactorers/test_member_ignoring_method.py b/tests/refactorers/test_member_ignoring_method.py new file mode 100644 index 00000000..1531049b --- /dev/null +++ b/tests/refactorers/test_member_ignoring_method.py @@ -0,0 +1,364 @@ +import pytest + +import textwrap +from pathlib 
import Path + +from ecooptimizer.refactorers.concrete.member_ignoring_method import MakeStaticRefactorer +from ecooptimizer.data_types import MIMSmell, Occurence +from ecooptimizer.utils.smell_enums import PylintSmell + + +@pytest.fixture +def refactorer(): + return MakeStaticRefactorer() + + +def create_smell(occurences: list[int], obj: str): + """Factory function to create a smell object""" + + def _create(): + return MIMSmell( + path="fake.py", + module="some_module", + obj=obj, + type="refactor", + symbol="no-self-use", + message="Method could be a function", + messageId=PylintSmell.NO_SELF_USE.value, + confidence="INFERENCE", + occurences=[ + Occurence( + line=occ, + endLine=999, + column=999, + endColumn=999, + ) + for occ in occurences + ], + additionalInfo=None, + ) + + return _create + + +def test_mim_basic_case(source_files, refactorer): + """ + Tests that the member ignoring method refactorer: + - Adds @staticmethod decorator. + - Removes 'self' from method signature. + - Updates calls in external files. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_basic_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + example = Example() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import Example + example = Example() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + example = Example() + num = Example.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import Example + example = Example() + result = Example.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_mim_inheritence_case(source_files, refactorer): + """ + Tests that calls originating from a subclass instance are also refactored. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_inherited_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + class SubExample(Example): + pass + + example = SubExample() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + class SubExample(Example): + pass + + example = SubExample() + num = SubExample.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = SubExample.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_mim_inheritence_seperate_subclass(source_files, refactorer): + """ + Tests that subclasses declared in files other than the initial one are detected. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_inherited_ss_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + example = Example() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import Example + + class SubExample(Example): + pass + + example = SubExample() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + example = Example() + num = Example.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import Example + + class SubExample(Example): + pass + + example = SubExample() + result = SubExample.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_mim_inheritence_subclass_method_override(source_files, refactorer): + """ + Tests that calls to the mim method from subclass instance with method override are NOT changed. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_inherited_override_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + class SubExample(Example): + def mim_method(self, x): + return x * 3 + + example = Example() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + class SubExample(Example): + def mim_method(self, x): + return x * 3 + + example = Example() + num = Example.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = example.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_mim_type_hint_inferrence(source_files, refactorer): + """ + Tests that type hints declaring and instance type are detected. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_mim_type_hint_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + def test(example: Example): + print(example.mim_method(3)) + + example = Example() + num = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + def test(example: Example): + print(Example.mim_method(3)) + + example = Example() + num = Example.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() diff --git a/tests/refactorers/test_repeated_calls_refactor.py b/tests/refactorers/test_repeated_calls_refactor.py new file mode 100644 index 00000000..162d680d --- /dev/null +++ b/tests/refactorers/test_repeated_calls_refactor.py @@ -0,0 +1,249 @@ +import pytest +import textwrap +from pathlib import Path +from ecooptimizer.refactorers.concrete.repeated_calls import CacheRepeatedCallsRefactorer +from ecooptimizer.data_types import CRCSmell, Occurence, CRCInfo + + +@pytest.fixture +def refactorer(): + return CacheRepeatedCallsRefactorer() + + +def create_smell(occurences: list[dict[str, int]], call_string: str, repetitions: int): + """Factory function to create a CRCSmell object with accurate metadata.""" + + def _create(): + return CRCSmell( + path="fake.py", + module="some_module", + obj=None, + type="performance", + symbol="cached-repeated-calls", + message=f"Repeated function call detected ({repetitions}/{repetitions}). 
Consider caching the result: {call_string}", + messageId="CRC001", + confidence="HIGH" if repetitions > 2 else "MEDIUM", + occurences=[ + Occurence( + line=occ["line"], + endLine=occ["endLine"], + column=occ["column"], + endColumn=occ["endColumn"], + ) + for occ in occurences + ], + additionalInfo=CRCInfo( + repetitions=repetitions, + callString=call_string, + ), + ) + + return _create + + +def test_crc_basic_case(source_files, refactorer): + """ + Tests that repeated function calls are cached properly. + """ + test_dir = Path(source_files, "temp_crc_basic") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "crc_def.py" + file1.write_text( + textwrap.dedent(""" + def expensive_function(x): + return x * x + + def test_case(): + result1 = expensive_function(42) + result2 = expensive_function(42) + result3 = expensive_function(42) + return result1 + result2 + result3 + """) + ) + + smell = create_smell( + occurences=[ + {"line": 6, "endLine": 6, "column": 14, "endColumn": 38}, + {"line": 7, "endLine": 7, "column": 14, "endColumn": 38}, + {"line": 8, "endLine": 8, "column": 14, "endColumn": 38}, + ], + call_string="expensive_function(42)", + repetitions=3, + )() + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + expected_file1 = textwrap.dedent(""" + def expensive_function(x): + return x * x + + def test_case(): + cached_expensive_function = expensive_function(42) + result1 = cached_expensive_function + result2 = cached_expensive_function + result3 = cached_expensive_function + return result1 + result2 + result3 + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_crc_method_calls(source_files, refactorer): + """ + Tests that repeated method calls on an object are cached properly. 
+ """ + test_dir = Path(source_files, "temp_crc_method") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "crc_def.py" + file1.write_text( + textwrap.dedent(""" + class Demo: + def __init__(self, value): + self.value = value + def compute(self): + return self.value * 2 + + def test_case(): + obj = Demo(3) + result1 = obj.compute() + result2 = obj.compute() + return result1 + result2 + """) + ) + + smell = create_smell( + occurences=[ + {"line": 10, "endLine": 10, "column": 14, "endColumn": 28}, + {"line": 11, "endLine": 11, "column": 14, "endColumn": 28}, + ], + call_string="obj.compute()", + repetitions=2, + )() + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + expected_file1 = textwrap.dedent(""" + class Demo: + def __init__(self, value): + self.value = value + def compute(self): + return self.value * 2 + + def test_case(): + obj = Demo(3) + cached_obj_compute = obj.compute() + result1 = cached_obj_compute + result2 = cached_obj_compute + return result1 + result2 + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_crc_instance_method_repeated(source_files, refactorer): + """ + Tests that repeated method calls on the same object instance are cached. 
+ """ + test_dir = Path(source_files, "temp_crc_instance_method") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "crc_def.py" + file1.write_text( + textwrap.dedent(""" + class Demo: + def __init__(self, value): + self.value = value + def compute(self): + return self.value * 2 + + def test_case(): + demo1 = Demo(1) + demo2 = Demo(2) + result1 = demo1.compute() + result2 = demo2.compute() + result3 = demo1.compute() + return result1 + result2 + result3 + """) + ) + + smell = create_smell( + occurences=[ + {"line": 11, "endLine": 11, "column": 14, "endColumn": 28}, + {"line": 13, "endLine": 13, "column": 14, "endColumn": 28}, + ], + call_string="demo1.compute()", + repetitions=2, + )() + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + expected_file1 = textwrap.dedent(""" + class Demo: + def __init__(self, value): + self.value = value + def compute(self): + return self.value * 2 + + def test_case(): + demo1 = Demo(1) + cached_demo1_compute = demo1.compute() + demo2 = Demo(2) + result1 = cached_demo1_compute + result2 = demo2.compute() + result3 = cached_demo1_compute + return result1 + result2 + result3 + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_crc_with_docstrigs(source_files, refactorer): + """ + Tests that repeated function calls are cached properly when docstrings present. 
+ """ + test_dir = Path(source_files, "temp_crc_docstring") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "crc_def.py" + file1.write_text( + textwrap.dedent(''' + def expensive_function(x): + return x * x + + def test_case(): + """ + Example docstring + """ + result1 = expensive_function(100) + result2 = expensive_function(100) + result3 = expensive_function(42) + return result1 + result2 + result3 + ''') + ) + + smell = create_smell( + occurences=[ + {"line": 9, "endLine": 9, "column": 14, "endColumn": 38}, + {"line": 10, "endLine": 10, "column": 14, "endColumn": 38}, + {"line": 11, "endLine": 11, "column": 14, "endColumn": 38}, + ], + call_string="expensive_function(100)", + repetitions=3, + )() + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + expected_file1 = textwrap.dedent(''' + def expensive_function(x): + return x * x + + def test_case(): + """ + Example docstring + """ + cached_expensive_function = expensive_function(100) + result1 = cached_expensive_function + result2 = cached_expensive_function + result3 = expensive_function(42) + return result1 + result2 + result3 + ''') + + assert file1.read_text().strip() == expected_file1.strip() diff --git a/tests/refactorers/test_str_concat_in_loop_refactor.py b/tests/refactorers/test_str_concat_in_loop_refactor.py new file mode 100644 index 00000000..ce75616a --- /dev/null +++ b/tests/refactorers/test_str_concat_in_loop_refactor.py @@ -0,0 +1,439 @@ +import pytest +from unittest.mock import patch + +from pathlib import Path + +from ecooptimizer.refactorers.concrete.str_concat_in_loop import UseListAccumulationRefactorer +from ecooptimizer.data_types import SCLInfo, Occurence, SCLSmell +from ecooptimizer.utils.smell_enums import CustomSmell + + +@pytest.fixture +def refactorer(): + return UseListAccumulationRefactorer() + + +def create_smell(occurences: list[int], concat_target: str, inner_loop_line: int): + """Factory function to create a smell object""" + + def _create(): + return 
SCLSmell( + path="fake.py", + module="some_module", + obj=None, + type="performance", + symbol="string-concat-loop", + message="String concatenation inside loop detected", + messageId=CustomSmell.STR_CONCAT_IN_LOOP.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=occ, + endLine=999, + column=999, + endColumn=999, + ) + for occ in occurences + ], + additionalInfo=SCLInfo( + concatTarget=concat_target, + innerLoopLine=inner_loop_line, + ), + ) + + return _create + + +@pytest.mark.parametrize("val", [("''"), ('""'), ("str()")]) +def test_empty_initial_var(refactorer, val): + """Test for inital concat var being empty.""" + code = f""" + def example(): + result = {val} + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert "result = []\n" in written_code + assert f"result = {val}\n" not in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_non_empty_initial_name_var_not_referenced(refactorer): + """Test for initial concat value being none empty.""" + code = """ + def example(): + result = "Hello" + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + 
mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert "result = ['Hello']\n" in written_code + assert 'result = "Hello"\n' not in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_non_empty_initial_name_var_referenced(refactorer): + """Test for initialization when var is referenced after but before the loop start.""" + code = """ + def example(): + result = "Hello" + backup = result + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[6], concat_target="result", inner_loop_line=5)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert 'result = "Hello"\n' in written_code + assert "result = [result]\n" in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_initial_not_name_var(refactorer): + """Test that none name vars are initialized to a temp list""" + code = """ + def example(): + result = {"key" : "Hello"} + for i in range(5): + result["key"] += str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target='result["key"]', inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + list_name = refactorer.generate_temp_list_name() + + 
mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert 'result = {"key" : "Hello"}\n' in written_code + assert f'{list_name} = [result["key"]]\n' in written_code + + assert f"{list_name}.append(str(i))\n" in written_code + + assert f"result[\"key\"] = ''.join({list_name})\n" in written_code + + +def test_initial_not_in_scope(refactorer): + """Test for refactoring of a concat variable not initialized in the same scope.""" + code = """ + def example(result: str): + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[4], concat_target="result", inner_loop_line=3)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert "result = [result]\n" in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_insert_on_prefix(refactorer): + """Ensure insert(0) is used for prefix concatenation""" + code = """ + def example(): + result = "" + for i in range(5): + result = str(i) + result + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first 
argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.insert(0, str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_concat_with_prefix_and_suffix(refactorer): + """Test for proper refactoring of a concatenation containing both a prefix and suffix concat.""" + code = """ + def example(): + result = "" + for i in range(5): + result = str(i) + result + str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.insert(0, str(i))\n" in written_code + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_multiple_concat_occurrences(refactorer): + """Test for multiple successive concatenations in the same loop for 1 smell.""" + code = """ + def example(): + result = "" + fruits = ["apple", "banana", "orange", "kiwi"] + for fruit in fruits: + result += fruit + result = fruit + result + return result + """ + smell = create_smell(occurences=[6, 7], concat_target="result", inner_loop_line=5)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + 
assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.append(fruit)\n" in written_code + assert "result.insert(0, fruit)\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_nested_concat(refactorer): + """Test for nested concat in loop.""" + code = """ + def example(): + result = "" + for i in range(5): + for j in range(6): + result = str(i) + result + str(j) + return result + """ + smell = create_smell(occurences=[6], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.append(str(j))\n" in written_code + assert "result.insert(0, str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_multi_occurrence_nested_concat(refactorer): + """Test for multiple occurrences of a same smell at different loop levels.""" + code = """ + def example(): + result = "" + for i in range(5): + result += str(i) + for j in range(6): + result = result + str(j) + return result + """ + smell = create_smell(occurences=[5, 7], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' 
not in written_code + + assert "result.append(str(i))\n" in written_code + assert "result.append(str(j))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_reassignment(refactorer): + """Ensure list is reset to new val when reassigned inside the loop.""" + code = """ + class Test: + def __init__(self): + self.text = "" + obj = Test() + for word in ["bug", "warning", "Hello", "World"]: + obj.text += word + if word == "warning": + obj.text = "Well, " + """ + smell = create_smell(occurences=[7], concat_target="obj.text", inner_loop_line=6)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + list_name = refactorer.generate_temp_list_name() + + assert f"{list_name} = [obj.text]\n" in written_code + + assert f"{list_name}.append(word)\n" in written_code + assert f"{list_name} = ['Well, ']\n" in written_code # astroid changes quotes + assert 'obj.text = "Well, "\n' not in written_code + + +@pytest.mark.parametrize("val", [("''"), ('""'), ("str()")]) +def test_reassignment_clears_list(refactorer, val): + """Ensure list is cleared when reassigned inside the loop using clear().""" + code = f""" + class Test: + def __init__(self): + self.text = "" + obj = Test() + for word in ["bug", "warning", "Hello", "World"]: + obj.text += word + if word == "warning": + obj.text = {val} + """ + smell = create_smell(occurences=[7], concat_target="obj.text", inner_loop_line=6)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure 
write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + list_name = refactorer.generate_temp_list_name() + + assert f"{list_name} = [obj.text]\n" in written_code + + assert f"{list_name}.append(word)\n" in written_code + assert f"{list_name}.clear()\n" in written_code + + +def test_no_unrelated_modifications(refactorer): + """Ensure formatting and any comments for unrelated lines are preserved.""" + code = """ + def example(): + print("Hello World") + # This is a comment + result = "" + unrelated_var = 0 + for i in range(5): # This is also a comment + result += str(i) + unrelated_var += i # Yep, you guessed it, comment + return result # Another one here + random = example() # And another one, why not + """ + smell = create_smell(occurences=[8], concat_target="result", inner_loop_line=7)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code: str = mock_write_text.call_args[0][0] # The first argument is the modified code + + original_lines = code.split("\n") + modified_lines = written_code.split("\n") + + assert all(line_o == line_m for line_o, line_m in zip(original_lines[:4], modified_lines[:4])) + assert all(line_o == line_m for line_o, line_m in zip(original_lines[5:7], modified_lines[5:7])) + assert original_lines[8] == modified_lines[8] + assert original_lines[9] == modified_lines[10] + assert original_lines[10] == modified_lines[11]