
Refactor(handler part 13): vae encode #594

Merged
ChuxiJ merged 1 commit into ace-step:main from 1larity:feat/decompose-handler-part10-pr3-vae-encode
Feb 15, 2026

Conversation

Contributor

1larity commented Feb 15, 2026

Summary

Extract VAE encode/tiled encode logic from acestep/handler.py into focused mixins while preserving runtime behaviour.

This continues the handler functional-decomposition (FD) plan by slimming the facade and moving encode responsibilities into dedicated modules.

What changed

Decomposition

  • Added acestep/core/generation/handler/vae_encode.py
    • VaeEncodeMixin
    • tiled_encode(...) orchestration and path selection
  • Added acestep/core/generation/handler/vae_encode_chunks.py
    • VaeEncodeChunksMixin
    • _tiled_encode_gpu(...)
    • _tiled_encode_offload_cpu(...)
  • Wired mixins in:
    • acestep/core/generation/handler/__init__.py
    • acestep/handler.py
  • Removed moved VAE encode implementation block from handler.py facade section.
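
A minimal sketch of this wiring, assuming only the module and mixin names listed above (other bases, imports, and exports in the real files are omitted):

# acestep/core/generation/handler/__init__.py (sketch; existing exports omitted)
from .vae_encode import VaeEncodeMixin
from .vae_encode_chunks import VaeEncodeChunksMixin

__all__ = ["VaeEncodeMixin", "VaeEncodeChunksMixin"]

# acestep/handler.py (sketch): the facade composes the mixins so the public
# tiled_encode(...) entry point keeps its existing call sites.
from acestep.core.generation.handler import VaeEncodeMixin, VaeEncodeChunksMixin

class AceStepHandler(VaeEncodeMixin, VaeEncodeChunksMixin):
    ...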

Tests

  • Added acestep/core/generation/handler/vae_encode_test.py
    • MLX fast-path routing
    • direct encode path
    • offload routing
    • invalid stride validation
    • GPU routing path
    • GPU/offload output parity on deterministic input
    • offload output device assertion
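
For illustration, a hypothetical shape of the stride-validation case, written against the test module's `_Host` stub (the stub name appears in the review comments below; the exception type and tensor sizes are assumptions, not the committed test):

import unittest
import torch

class TiledEncodeValidationTest(unittest.TestCase):
    def test_tiled_encode_rejects_invalid_stride(self):
        """An overlap at or above the chunk size leaves no positive stride to advance by."""
        host = _Host()  # assumed test double wrapping the mixins around a stub VAE
        audio = torch.zeros(1, 2, 4096)  # [batch, channels, samples]
        with self.assertRaises(ValueError):  # exception type is an assumption
            host.tiled_encode(audio, chunk_size=16, overlap=16)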

Behavioral parity notes

  • No public API/signature changes intended.
  • Preserves:
    • MLX fast path + fallback semantics
    • chunk sizing/overlap defaults and stride validation
    • offload vs GPU encode routing
    • latent trimming/downsample flow in chunk methods
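
For intuition, the preserved stride/overlap arithmetic amounts to something like the following standalone sketch (not the mixin's actual code; MLX routing and latent-frame trimming are omitted):

def plan_chunk_windows(total_samples: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:
    """Plan (start, end) sample windows so adjacent chunks share `overlap` samples."""
    stride = chunk_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than chunk_size")
    return [
        (start, min(start + chunk_size, total_samples))
        for start in range(0, total_samples, stride)
    ]

After encoding, the overlapping latent frames of adjacent windows are trimmed before concatenation, which is the trimming/downsample flow referred to above.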

Validation

  • py_compile passed on changed files.
  • Unit tests are included; optional dependencies missing from this environment may prevent running the full suite locally:
    • test_tiled_encode_uses_mlx_path_when_available
    • test_tiled_encode_direct_path_for_short_audio
    • test_tiled_encode_uses_offload_path
    • test_tiled_encode_rejects_invalid_stride
    • test_tiled_encode_routes_to_gpu_chunk_path
    • test_tiled_encode_gpu_and_offload_outputs_match
    • test_tiled_encode_offload_returns_cpu_tensor
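
When the optional dependencies are present, the suite could be driven programmatically along these lines (an illustrative invocation, not part of the PR):

import unittest

# Load the new VAE encode tests by dotted module path and run them verbosely.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "acestep.core.generation.handler.vae_encode_test"
)
unittest.TextTestRunner(verbosity=2).run(suite)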

Manual UI testing

Manual UI tests passed for available platform paths:

  • basic text2music generate
  • tiled encode scenarios
  • cover/repaint source-audio encode path
  • batch flow

Apple-specific (MPS/MLX) manual tests: not run (environment unavailable).

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced VAE audio encoding with tiled processing and configurable overlap handling
    • Automatic chunk size selection based on available GPU memory
    • Multiple encoding pathways with adaptive fallback support
    • Optional CPU offload capability for improved memory efficiency
  • Tests

    • Comprehensive test coverage for VAE encoding functionality added

Contributor

coderabbitai bot commented Feb 15, 2026

📝 Walkthrough

Refactored VAE audio encoding by extracting tiled encoding logic into two new mixins (VaeEncodeMixin and VaeEncodeChunksMixin) within the handler package, removing old tiled helpers from AceStepHandler and replacing them with mixin-based implementations supporting GPU and CPU-offload strategies.

Changes

  • Handler Package Exports: acestep/core/generation/handler/__init__.py
    Added imports and exports for the two new mixins, VaeEncodeMixin and VaeEncodeChunksMixin, to the package's public API.
  • AceStepHandler Updates: acestep/handler.py
    Updated the class to inherit VaeEncodeMixin and VaeEncodeChunksMixin; removed the entire tiled_encode pipeline, including _tiled_encode_gpu, _tiled_encode_offload_cpu, _tiled_encode_inner, and tiled_encode (203 lines removed).
  • VAE Encode Orchestrator: acestep/core/generation/handler/vae_encode.py
    New mixin providing the high-level tiled_encode method with automatic GPU-memory-aware chunk sizing, MLX fast-path support, overlap handling, and delegation to GPU or CPU-offload strategies.
  • VAE Encode Chunks Implementations: acestep/core/generation/handler/vae_encode_chunks.py
    New mixin providing two chunked encoding strategies: _tiled_encode_gpu (GPU-based) and _tiled_encode_offload_cpu (GPU encode with CPU latent buffering); both handle overlap expansion and latent trimming.
  • VAE Encode Tests: acestep/core/generation/handler/vae_encode_test.py
    New test module with comprehensive coverage of the tiled encoding paths, including MLX support, the direct short-audio path, offload routing, stride validation, GPU consistency, and tensor return properties.
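
As an aside, a rough sketch of what GPU-memory-aware chunk sizing can look like in PyTorch (an illustrative heuristic with assumed names and constants, not the PR's actual formula):

import torch

def _auto_chunk_size(default: int = 512_000) -> int:
    """Pick a larger encode window when more free GPU memory is available (heuristic)."""
    if not torch.cuda.is_available():
        return default
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    # Scale the window with free memory, clamped to a sane range.
    return int(min(max(default, free_bytes // 4096), 4 * default))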

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant VaeEncodeMixin
    participant VaeEncodeChunksMixin
    participant VAE
    participant GPU/CPU Device

    Client->>VaeEncodeMixin: tiled_encode(audio, chunk_size, overlap)
    
    alt Small Input (single chunk)
        VaeEncodeMixin->>VAE: encode(audio)
        VAE->>GPU/CPU Device: compute latents
        GPU/CPU Device-->>VAE: latents
        VAE-->>VaeEncodeMixin: latents
    else Large Input (multi-chunk)
        VaeEncodeMixin->>VaeEncodeMixin: compute stride, validate params
        
        alt MLX Path Available
            VaeEncodeMixin->>GPU/CPU Device: mlx_vae.encode(audio)
            GPU/CPU Device-->>VaeEncodeMixin: latents
        else PyTorch Path
            VaeEncodeMixin->>VaeEncodeChunksMixin: delegate to _tiled_encode_*
            
            alt offload_latent_to_cpu = true
                VaeEncodeChunksMixin->>VAE: encode first chunk on GPU
                VAE->>GPU/CPU Device: compute latent chunk
                GPU/CPU Device-->>VAE: latent chunk
                VAE-->>VaeEncodeChunksMixin: latent chunk
                VaeEncodeChunksMixin->>GPU/CPU Device: allocate CPU tensor
                loop For remaining chunks
                    VaeEncodeChunksMixin->>VAE: encode chunk (with overlap window)
                    VAE->>GPU/CPU Device: compute latent
                    GPU/CPU Device-->>VAE: latent
                    VAE-->>VaeEncodeChunksMixin: latent
                    VaeEncodeChunksMixin->>GPU/CPU Device: append to CPU tensor
                end
                VaeEncodeChunksMixin-->>VaeEncodeMixin: final CPU latent tensor
            else offload_latent_to_cpu = false
                loop For all chunks
                    VaeEncodeChunksMixin->>VAE: encode chunk on GPU
                    VAE->>GPU/CPU Device: compute latent
                    GPU/CPU Device-->>VAE: latent
                    VAE-->>VaeEncodeChunksMixin: latent
                    VaeEncodeChunksMixin->>GPU/CPU Device: concatenate on GPU
                end
                VaeEncodeChunksMixin-->>VaeEncodeMixin: final GPU latent tensor
            end
        end
    end
    
    VaeEncodeMixin->>VaeEncodeMixin: normalize output shape
    VaeEncodeMixin-->>Client: latents
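
In plain Python, the routing shown in the diagram corresponds roughly to this sketch (attribute names such as vae, mlx_vae, and the helper signatures are assumptions; the real mixin also does memory-aware chunk sizing and output-shape normalization):

def tiled_encode(self, audio, chunk_size=None, overlap=None, offload_latent_to_cpu=True):
    chunk_size = chunk_size or self._default_chunk_size()  # hypothetical memory-aware default

    # Small inputs: one direct VAE encode, no tiling needed.
    if audio.shape[-1] <= chunk_size:
        return self.vae.encode(audio)

    # Large inputs: compute and validate the stride (omitted), then pick a path.
    if getattr(self, "mlx_vae", None) is not None:
        return self.mlx_vae.encode(audio)  # MLX fast path on Apple silicon

    # PyTorch path: delegate to the chunked strategies from VaeEncodeChunksMixin.
    if offload_latent_to_cpu:
        return self._tiled_encode_offload_cpu(audio, chunk_size, overlap)
    return self._tiled_encode_gpu(audio, chunk_size, overlap)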

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • ChuxiJ

Poem

🐰 Encode chunks now hop with grace,
GPU and CPU in the right place,
Overlap windows dance to the beat,
Tiled paths make latents complete! 🎵

🚥 Pre-merge checks | ✅ 4 passed

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title 'Refactor(handler part 13): vae encode' clearly and specifically describes the main change: refactoring VAE encoding logic into dedicated mixins, which is the core purpose of the PR.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient (required threshold: 80.00%).
  • Merge Conflict Detection: ✅ Passed. No merge conflicts detected when merging into main.


Contributor

coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@acestep/llm_inference.py`:
- Around line 117-125: The _cleanup_torch_distributed_state function currently
uses a broad except Exception; change it to catch only expected errors: handle
ImportError when importing torch.distributed, and handle RuntimeError (or
specific torch.distributed exceptions) around dist.is_available(),
dist.is_initialized(), and dist.destroy_process_group(); log each error via
logger.warning with context (e.g., "[LLM vLLM] Failed to clean torch distributed
state: ...") and allow any other unexpected exceptions to propagate so they
aren't silently swallowed. Ensure you keep the same behavior of warning before
calling dist.destroy_process_group() and reference the same symbols
(_cleanup_torch_distributed_state, torch.distributed as dist,
dist.destroy_process_group, logger).
🧹 Nitpick comments (2)
acestep/core/generation/handler/vae_decode_chunks.py (1)

13-166: Add type hints to the new chunked decode helpers for clarity.

The mixin methods are new and untyped; adding basic annotations improves readability and matches the repo’s typing guideline.

Suggested signature annotations
-    def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
+    def _tiled_decode_inner(
+        self,
+        latents: torch.Tensor,
+        chunk_size: int,
+        overlap: int,
+        offload_wav_to_cpu: bool,
+    ) -> torch.Tensor:
@@
-    def _tiled_decode_gpu(self, latents, stride, overlap, num_steps):
+    def _tiled_decode_gpu(
+        self,
+        latents: torch.Tensor,
+        stride: int,
+        overlap: int,
+        num_steps: int,
+    ) -> torch.Tensor:
@@
-    def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap, num_steps):
+    def _tiled_decode_offload_cpu(
+        self,
+        latents: torch.Tensor,
+        bsz: int,
+        latent_frames: int,
+        stride: int,
+        overlap: int,
+        num_steps: int,
+    ) -> torch.Tensor:

As per coding guidelines: Add type hints for new/modified functions when practical in Python.

acestep/core/generation/handler/vae_decode.py (1)

16-127: Add type hints to the new decode orchestration APIs.

These new mixin methods are untyped; annotating inputs/outputs improves clarity and aligns with the typing guideline.

Suggested signature annotations
-    def tiled_decode(
-        self,
-        latents,
-        chunk_size: Optional[int] = None,
-        overlap: int = 64,
-        offload_wav_to_cpu: Optional[bool] = None,
-    ):
+    def tiled_decode(
+        self,
+        latents: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        overlap: int = 64,
+        offload_wav_to_cpu: Optional[bool] = None,
+    ) -> torch.Tensor:
@@
-    def _tiled_decode_cpu_fallback(self, latents):
+    def _tiled_decode_cpu_fallback(self, latents: torch.Tensor) -> torch.Tensor:
@@
-    def _decode_on_cpu(self, latents):
+    def _decode_on_cpu(self, latents: torch.Tensor) -> torch.Tensor:

As per coding guidelines: Add type hints for new/modified functions when practical in Python.

Comment on lines 117 to 125
    def _cleanup_torch_distributed_state(self) -> None:
        """Destroy default torch distributed process group when already initialized."""
        try:
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized():
                logger.warning("[LLM vLLM] Destroying stale default process group before/after vLLM lifecycle")
                dist.destroy_process_group()
        except Exception as exc:
            logger.warning(f"[LLM vLLM] Failed to clean torch distributed state: {exc}")
Contributor

coderabbitai bot commented Feb 15, 2026

⚠️ Potential issue | 🟡 Minor

Avoid broad except Exception in distributed cleanup.

Catching Exception here will hide unexpected bugs; prefer explicit exceptions for import/runtime issues.

🔧 Suggested fix
-        except Exception as exc:
+        except (ImportError, RuntimeError, ValueError) as exc:
             logger.warning(f"[LLM vLLM] Failed to clean torch distributed state: {exc}")

As per coding guidelines: “Handle errors explicitly in Python; avoid bare except.”
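
Applied to the function quoted above, the narrowed handling would read roughly like this (a sketch; the exception tuple follows the suggested diff, everything else mirrors the existing code):

    def _cleanup_torch_distributed_state(self) -> None:
        """Destroy default torch distributed process group when already initialized."""
        try:
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized():
                logger.warning("[LLM vLLM] Destroying stale default process group before/after vLLM lifecycle")
                dist.destroy_process_group()
        except (ImportError, RuntimeError, ValueError) as exc:
            # Narrowed exception types; unexpected errors now propagate.
            logger.warning(f"[LLM vLLM] Failed to clean torch distributed state: {exc}")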

🧰 Tools
🪛 Ruff (0.15.0)

[warning] 124-124: Do not catch blind exception: Exception

(BLE001)


Contributor Author (1larity)
Pre-existing, will fix.

Contributor (coderabbitai bot)
@1larity, thank you for acknowledging this! I understand this was pre-existing code and you'll address the broad exception handling in a future update. The specific exception types (ImportError, RuntimeError, ValueError) will help ensure unexpected errors aren't silently swallowed.


🧠 Learnings used
Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: AGENTS.md:0-0
Timestamp: 2026-02-13T12:56:21.792Z
Learning: Applies to **/*.py : Handle errors explicitly in Python; avoid bare `except`

1larity force-pushed the feat/decompose-handler-part10-pr3-vae-encode branch from 317138e to 56cfb74 on February 15, 2026 20:01
Contributor

coderabbitai bot left a comment

🧹 Nitpick comments (3)
acestep/core/generation/handler/vae_encode_chunks.py (1)

34-38: Consider deleting latent_chunk after extracting the core to free GPU memory sooner.

In the GPU path, latent_chunk is kept alive until the next iteration. For memory-constrained GPUs, explicitly deleting it (like audio_chunk) could help.

♻️ Optional memory optimization
             latent_core = latent_chunk[:, :, trim_start:end_idx]
             encoded_latent_list.append(latent_core)
-            del audio_chunk
+            del audio_chunk, latent_chunk
acestep/core/generation/handler/vae_encode.py (1)

15-26: Consider adding type hints to the method signature for clarity.

The docstring documents the parameters well, but explicit type hints would improve IDE support and catch type errors earlier.

♻️ Optional type hints
-    def tiled_encode(self, audio, chunk_size=None, overlap=None, offload_latent_to_cpu=True):
+    def tiled_encode(
+        self,
+        audio: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        overlap: Optional[int] = None,
+        offload_latent_to_cpu: bool = True,
+    ) -> torch.Tensor:
acestep/core/generation/handler/vae_encode_test.py (1)

69-146: Consider adding a test for 2D input shape handling.

The tiled_encode method handles both 2D [channels, samples] and 3D [batch, channels, samples] inputs. Adding a test for the 2D path would improve coverage.

🧪 Optional test for 2D input
def test_tiled_encode_handles_2d_input(self):
    """2D input should be unsqueezed, processed, and squeezed back."""
    host = _Host()
    audio_2d = torch.zeros(2, 16)  # [channels, samples]
    out = host.tiled_encode(audio_2d, chunk_size=20, overlap=2)
    # Output should be 2D: [latent_channels, latent_frames]
    self.assertEqual(len(out.shape), 2)
    self.assertEqual(out.shape[0], 4)  # latent_channels

ChuxiJ merged commit 30fe0bf into ace-step:main Feb 15, 2026
3 checks passed