Refactor(handler part 13): vae encode #594
Conversation
📝 Walkthrough

Refactored VAE audio encoding by extracting tiled encoding logic into two new mixins (VaeEncodeMixin and VaeEncodeChunksMixin) within the handler package, removing old tiled helpers from AceStepHandler and replacing them with mixin-based implementations supporting GPU and CPU-offload strategies.

Changes
Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant VaeEncodeMixin
    participant VaeEncodeChunksMixin
    participant VAE
    participant GPU/CPU Device

    Client->>VaeEncodeMixin: tiled_encode(audio, chunk_size, overlap)
    alt Small Input (single chunk)
        VaeEncodeMixin->>VAE: encode(audio)
        VAE->>GPU/CPU Device: compute latents
        GPU/CPU Device-->>VAE: latents
        VAE-->>VaeEncodeMixin: latents
    else Large Input (multi-chunk)
        VaeEncodeMixin->>VaeEncodeMixin: compute stride, validate params
        alt MLX Path Available
            VaeEncodeMixin->>GPU/CPU Device: mlx_vae.encode(audio)
            GPU/CPU Device-->>VaeEncodeMixin: latents
        else PyTorch Path
            VaeEncodeMixin->>VaeEncodeChunksMixin: delegate to _tiled_encode_*
            alt offload_latent_to_cpu = true
                VaeEncodeChunksMixin->>VAE: encode first chunk on GPU
                VAE->>GPU/CPU Device: compute latent chunk
                GPU/CPU Device-->>VAE: latent chunk
                VAE-->>VaeEncodeChunksMixin: latent chunk
                VaeEncodeChunksMixin->>GPU/CPU Device: allocate CPU tensor
                loop For remaining chunks
                    VaeEncodeChunksMixin->>VAE: encode chunk (with overlap window)
                    VAE->>GPU/CPU Device: compute latent
                    GPU/CPU Device-->>VAE: latent
                    VAE-->>VaeEncodeChunksMixin: latent
                    VaeEncodeChunksMixin->>GPU/CPU Device: append to CPU tensor
                end
                VaeEncodeChunksMixin-->>VaeEncodeMixin: final CPU latent tensor
            else offload_latent_to_cpu = false
                loop For all chunks
                    VaeEncodeChunksMixin->>VAE: encode chunk on GPU
                    VAE->>GPU/CPU Device: compute latent
                    GPU/CPU Device-->>VAE: latent
                    VAE-->>VaeEncodeChunksMixin: latent
                    VaeEncodeChunksMixin->>GPU/CPU Device: concatenate on GPU
                end
                VaeEncodeChunksMixin-->>VaeEncodeMixin: final GPU latent tensor
            end
        end
    end
    VaeEncodeMixin->>VaeEncodeMixin: normalize output shape
    VaeEncodeMixin-->>Client: latents
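To make the flow above concrete, here is a minimal, self-contained sketch of the chunked-encode pattern the diagram describes. The function name, the `downsample_rate`, and the `vae_encode` callable are illustrative assumptions, not the handler's actual API:

```python
import torch


def tiled_encode_sketch(vae_encode, audio, chunk_size, overlap,
                        downsample_rate=4, offload_latent_to_cpu=True):
    """Encode long audio in overlapping windows and stitch the latents.

    vae_encode: any callable mapping [B, C, samples] -> [B, C_lat, frames].
    All names here are placeholders for illustration only.
    """
    _, _, total = audio.shape
    stride = chunk_size - 2 * overlap
    if stride <= 0:
        raise ValueError("chunk_size must be larger than 2 * overlap")

    pieces = []
    for start in range(0, total, stride):
        # Widen each window by `overlap` samples on both sides, clamped to bounds.
        win_start = max(start - overlap, 0)
        win_end = min(start + stride + overlap, total)
        latent = vae_encode(audio[:, :, win_start:win_end])

        # Trim the overlap (in latent frames) so neighbouring chunks tile without seams.
        trim_start = (start - win_start) // downsample_rate
        keep = (min(start + stride, total) - start) // downsample_rate
        core = latent[:, :, trim_start:trim_start + keep]

        # Offload strategy: move each finished chunk to CPU, or keep everything on GPU.
        pieces.append(core.cpu() if offload_latent_to_cpu else core)

    return torch.cat(pieces, dim=-1)
```

Per the diagram, the real VaeEncodeChunksMixin goes further: the CPU-offload path pre-allocates the destination tensor from the first chunk instead of concatenating at the end, and an MLX path bypasses the chunk loop entirely when available.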
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
🧪 Generate unit tests (beta)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@acestep/llm_inference.py`:
- Around line 117-125: The _cleanup_torch_distributed_state function currently
uses a broad except Exception; change it to catch only expected errors: handle
ImportError when importing torch.distributed, and handle RuntimeError (or
specific torch.distributed exceptions) around dist.is_available(),
dist.is_initialized(), and dist.destroy_process_group(); log each error via
logger.warning with context (e.g., "[LLM vLLM] Failed to clean torch distributed
state: ...") and allow any other unexpected exceptions to propagate so they
aren't silently swallowed. Ensure you keep the same behavior of warning before
calling dist.destroy_process_group() and reference the same symbols
(_cleanup_torch_distributed_state, torch.distributed as dist,
dist.destroy_process_group, logger).
🧹 Nitpick comments (2)
acestep/core/generation/handler/vae_decode_chunks.py (1)
13-166: Add type hints to the new chunked decode helpers for clarity.

The mixin methods are new and untyped; adding basic annotations improves readability and matches the repo’s typing guideline.
Suggested signature annotations
- def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
+ def _tiled_decode_inner(
+     self,
+     latents: torch.Tensor,
+     chunk_size: int,
+     overlap: int,
+     offload_wav_to_cpu: bool,
+ ) -> torch.Tensor:
@@
- def _tiled_decode_gpu(self, latents, stride, overlap, num_steps):
+ def _tiled_decode_gpu(
+     self,
+     latents: torch.Tensor,
+     stride: int,
+     overlap: int,
+     num_steps: int,
+ ) -> torch.Tensor:
@@
- def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap, num_steps):
+ def _tiled_decode_offload_cpu(
+     self,
+     latents: torch.Tensor,
+     bsz: int,
+     latent_frames: int,
+     stride: int,
+     overlap: int,
+     num_steps: int,
+ ) -> torch.Tensor:

As per coding guidelines: Add type hints for new/modified functions when practical in Python.
acestep/core/generation/handler/vae_decode.py (1)
16-127: Add type hints to the new decode orchestration APIs.

These new mixin methods are untyped; annotating inputs/outputs improves clarity and aligns with the typing guideline.
Suggested signature annotations
- def tiled_decode(
-     self,
-     latents,
-     chunk_size: Optional[int] = None,
-     overlap: int = 64,
-     offload_wav_to_cpu: Optional[bool] = None,
- ):
+ def tiled_decode(
+     self,
+     latents: torch.Tensor,
+     chunk_size: Optional[int] = None,
+     overlap: int = 64,
+     offload_wav_to_cpu: Optional[bool] = None,
+ ) -> torch.Tensor:
@@
- def _tiled_decode_cpu_fallback(self, latents):
+ def _tiled_decode_cpu_fallback(self, latents: torch.Tensor) -> torch.Tensor:
@@
- def _decode_on_cpu(self, latents):
+ def _decode_on_cpu(self, latents: torch.Tensor) -> torch.Tensor:

As per coding guidelines: Add type hints for new/modified functions when practical in Python.
acestep/llm_inference.py
Outdated
def _cleanup_torch_distributed_state(self) -> None:
    """Destroy default torch distributed process group when already initialized."""
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            logger.warning("[LLM vLLM] Destroying stale default process group before/after vLLM lifecycle")
            dist.destroy_process_group()
    except Exception as exc:
        logger.warning(f"[LLM vLLM] Failed to clean torch distributed state: {exc}")
Avoid broad except Exception in distributed cleanup.
Catching Exception here will hide unexpected bugs; prefer explicit exceptions for import/runtime issues.
🔧 Suggested fix
- except Exception as exc:
+ except (ImportError, RuntimeError, ValueError) as exc:
      logger.warning(f"[LLM vLLM] Failed to clean torch distributed state: {exc}")

As per coding guidelines: “Handle errors explicitly in Python; avoid bare except.”
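Expanding the one-line fix above into a full method, one possible shape is the following (assuming the module-level `logger`; this is an illustrative sketch, not the committed change):

```python
def _cleanup_torch_distributed_state(self) -> None:
    """Destroy the default torch distributed process group if it is initialized."""
    try:
        import torch.distributed as dist
    except ImportError as exc:
        # torch built without distributed support: nothing to clean up.
        logger.warning(f"[LLM vLLM] torch.distributed unavailable, skipping cleanup: {exc}")
        return
    try:
        if dist.is_available() and dist.is_initialized():
            logger.warning("[LLM vLLM] Destroying stale default process group before/after vLLM lifecycle")
            dist.destroy_process_group()
    except (RuntimeError, ValueError) as exc:
        # Expected teardown failures are logged; anything unexpected propagates.
        logger.warning(f"[LLM vLLM] Failed to clean torch distributed state: {exc}")
```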
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 124-124: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In `@acestep/llm_inference.py` around lines 117-125, the
_cleanup_torch_distributed_state function currently uses a broad except
Exception; change it to catch only expected errors: handle ImportError when
importing torch.distributed, and handle RuntimeError (or specific
torch.distributed exceptions) around dist.is_available(), dist.is_initialized(),
and dist.destroy_process_group(); log each error via logger.warning with context
(e.g., "[LLM vLLM] Failed to clean torch distributed state: ...") and allow any
other unexpected exceptions to propagate so they aren't silently swallowed.
Ensure you keep the same behavior of warning before calling
dist.destroy_process_group() and reference the same symbols
(_cleanup_torch_distributed_state, torch.distributed as dist,
dist.destroy_process_group, logger).
Pre-existing, will fix.
@1larity, thank you for acknowledging this! I understand this was pre-existing code and you'll address the broad exception handling in a future update. The specific exception types (ImportError, RuntimeError, ValueError) will help ensure unexpected errors aren't silently swallowed.
🧠 Learnings used
Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: AGENTS.md:0-0
Timestamp: 2026-02-13T12:56:21.792Z
Learning: Applies to **/*.py : Handle errors explicitly in Python; avoid bare `except`
317138e to 56cfb74
🧹 Nitpick comments (3)
acestep/core/generation/handler/vae_encode_chunks.py (1)
34-38: Consider deleting `latent_chunk` after extracting the core to free GPU memory sooner.

In the GPU path, `latent_chunk` is kept alive until the next iteration. For memory-constrained GPUs, explicitly deleting it (like `audio_chunk`) could help.

♻️ Optional memory optimization

  latent_core = latent_chunk[:, :, trim_start:end_idx]
  encoded_latent_list.append(latent_core)
- del audio_chunk
+ del audio_chunk, latent_chunk

acestep/core/generation/handler/vae_encode.py (1)
15-26: Consider adding type hints to the method signature for clarity.

The docstring documents the parameters well, but explicit type hints would improve IDE support and catch type errors earlier.
♻️ Optional type hints
- def tiled_encode(self, audio, chunk_size=None, overlap=None, offload_latent_to_cpu=True):
+ def tiled_encode(
+     self,
+     audio: torch.Tensor,
+     chunk_size: Optional[int] = None,
+     overlap: Optional[int] = None,
+     offload_latent_to_cpu: bool = True,
+ ) -> torch.Tensor:

acestep/core/generation/handler/vae_encode_test.py (1)
69-146: Consider adding a test for 2D input shape handling.

The `tiled_encode` method handles both 2D `[channels, samples]` and 3D `[batch, channels, samples]` inputs. Adding a test for the 2D path would improve coverage.

🧪 Optional test for 2D input

def test_tiled_encode_handles_2d_input(self):
    """2D input should be unsqueezed, processed, and squeezed back."""
    host = _Host()
    audio_2d = torch.zeros(2, 16)  # [channels, samples]
    out = host.tiled_encode(audio_2d, chunk_size=20, overlap=2)
    # Output should be 2D: [latent_channels, latent_frames]
    self.assertEqual(len(out.shape), 2)
    self.assertEqual(out.shape[0], 4)  # latent_channels
Summary
Extract VAE encode/tiled encode logic from `acestep/handler.py` into focused mixins while preserving runtime behaviour. This continues the handler FD plan by slimming the facade and moving encode responsibilities into dedicated modules.
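Roughly, the decomposition listed under "What changed" composes into the facade like this (class and method names are from this PR; parameter lists and bodies are placeholders):

```python
# vae_encode.py: orchestration and path selection (sketch only)
class VaeEncodeMixin:
    def tiled_encode(self, audio, chunk_size=None, overlap=None, offload_latent_to_cpu=True):
        """Choose the direct, MLX, GPU-chunked, or CPU-offload encode path."""
        ...


# vae_encode_chunks.py: the two PyTorch chunk loops (sketch only)
class VaeEncodeChunksMixin:
    def _tiled_encode_gpu(self, audio, stride, overlap):
        ...

    def _tiled_encode_offload_cpu(self, audio, stride, overlap):
        ...


# handler.py: the facade keeps its public surface by mixing these in.
class AceStepHandler(VaeEncodeMixin, VaeEncodeChunksMixin):  # other mixins elided in this sketch
    pass
```

Keeping the chunk loops in their own mixin keeps the orchestrator small and lets each strategy be exercised in isolation, which is what the new vae_encode_test.py does.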
What changed
Decomposition
- acestep/core/generation/handler/vae_encode.py: VaeEncodeMixin with tiled_encode(...) orchestration and path selection
- acestep/core/generation/handler/vae_encode_chunks.py: VaeEncodeChunksMixin with _tiled_encode_gpu(...) and _tiled_encode_offload_cpu(...)
- acestep/core/generation/handler/__init__.py and acestep/handler.py: new mixins wired into the handler.py facade section.

Tests
- acestep/core/generation/handler/vae_encode_test.py

Behavioral parity notes
Validation
py_compile passed on changed files.
test_tiled_encode_uses_mlx_path_when_available
test_tiled_encode_direct_path_for_short_audio
test_tiled_encode_uses_offload_path
test_tiled_encode_rejects_invalid_stride
test_tiled_encode_routes_to_gpu_chunk_path
test_tiled_encode_gpu_and_offload_outputs_match
test_tiled_encode_offload_returns_cpu_tensor
Manual UI testing
Manual UI tests passed for available platform paths:
Apple-specific (MPS/MLX) manual tests: not run (environment unavailable).
Summary by CodeRabbit
Release Notes
New Features
Tests