State space models (MAMBA) combined with flow matching for high-quality image generation from sparse (20%) pixel observations. Supports zero-shot super-resolution at arbitrary scales.
Features:
- MAMBA Architecture: Linear-complexity state space models for efficient sequence processing
- Flow Matching: Continuous normalizing flows for high-quality generation
- Sparse Training: Learn from only 20% of pixels with deterministic masking
- Zero-Shot Super-Resolution: Generate at 64×64, 96×96, 128×128, and 256×256 without training at those resolutions
- Four Architectures: V1 (MAMBA baseline), V2 (bidirectional + perceiver), V3 (Morton curves), V4 (Transformer comparison)

V1 (MAMBA Baseline):
- Unidirectional MAMBA with 6 layers
- Single cross-attention layer
- Row-major sequence ordering
- PSNR: ~28 dB, SSIM: ~0.85

V2 (Bidirectional + Perceiver):
- Bidirectional MAMBA: 4 forward + 4 backward = 8 layers
- Lightweight Perceiver: Query self-attention for spatial coherence
- Expected improvements:
  - 70-80% reduction in background speckles
  - +3-5 dB PSNR improvement
  - Smoother, more coherent spatial fields
- Trade-off: +71% computational cost

V3 (Morton Curves):
- Same architecture as V1 (6 layers, same parameters)
- Morton (Z-order) curve: Better spatial locality in sequences
- Expected improvements:
  - Better spatial coherence
  - Reduced artifacts from spatially-aware processing
  - +1-2 dB PSNR improvement
- Trade-off: Zero additional cost!

V4 (Transformer Comparison):
- Standard Transformer encoder instead of MAMBA
- Multi-head self-attention: Global context vs sequential state
- Purpose: Benchmark MAMBA's linear O(N) vs Transformer's quadratic O(N²)
- Expected:
  - 10-20x slower training than V1
  - Tests if global attention helps sparse neural fields
  - Fair comparison with same depth and dimension
- Trade-off: Much higher computational cost
Installation:

```bash
# Clone repository
git clone https://github.com/yourusername/MambaFlowMatching.git
cd MambaFlowMatching

# Install dependencies
pip install -r requirements.txt
```

V1 (Baseline):

```bash
cd v1/training
./run_mamba_training.sh
```

V2 (Bidirectional + Perceiver):

```bash
cd v2/training
./run_mamba_v2_training.sh
```

V3 (Morton Curves - Recommended):

```bash
cd v3/training
./run_mamba_v3_training.sh
```

V4 (Transformer Comparison):

```bash
cd v4/training
./run_transformer_v4_training.sh
```

Super-Resolution Evaluation:

```bash
cd v1/evaluation
./eval_superres.sh  # Tests 64×64, 96×96, 128×128, 256×256 resolutions
```

V1 vs V2 Comparison:

```bash
cd v2/evaluation
python eval_v1_vs_v2.py \
    --v1_checkpoint ../../v1/training/checkpoints_mamba/mamba_best.pth \
    --v2_checkpoint ../training/checkpoints_mamba_v2/mamba_v2_best.pth \
    --num_samples 20
```

Project Structure:

```
MambaFlowMatching/
├── core/ # Core modules
│ ├── neural_fields/ # Fourier features, perceiver components
│ ├── sparse/ # Sparse dataset handling, metrics
│ └── diffusion/ # Flow matching utilities
│
├── v1/ # V1 Architecture (Baseline)
│ ├── training/ # Training scripts
│ │ ├── train_mamba_standalone.py
│ │ ├── run_mamba_training.sh
│ │ ├── monitor_training.sh
│ │ └── stop_mamba_training.sh
│ └── evaluation/ # Evaluation scripts
│ ├── eval_superresolution.py
│ ├── eval_superres.sh
│ ├── eval_sde_multiscale.py
│ └── eval_sde.sh
│
├── v2/ # V2 Architecture (Bidirectional)
│ ├── training/ # Training scripts
│ │ ├── train_mamba_v2.py
│ │ └── run_mamba_v2_training.sh
│ └── evaluation/ # Evaluation scripts
│ └── eval_v1_vs_v2.py
│
├── v3/ # V3 Architecture (Morton Curves)
│ ├── training/ # Training scripts
│ │ ├── train_mamba_v3_morton.py
│ │ └── run_mamba_v3_training.sh
│ └── evaluation/ # Evaluation scripts (TBD)
│
├── v4/ # V4 Architecture (Transformer)
│ ├── training/ # Training scripts
│ │ ├── train_transformer_v4.py
│ │ └── run_transformer_v4_training.sh
│ └── evaluation/ # Evaluation scripts (TBD)
│
├── docs/ # Documentation
│ ├── README.md # Original documentation
│ ├── README_V2.md # V2 architecture details
│ ├── README_V3.md # V3 Morton curves guide
│ ├── README_V4.md # V4 Transformer comparison guide
│ ├── README_SUPERRES.md # Super-resolution guide
│ ├── README_SDE.md # SDE sampling guide
│ ├── QUICKSTART_EVAL.md # Quick evaluation guide
│ ├── QUICKSTART_SDE.md # Quick SDE guide
│ └── TRAINING_README.md # Training guide
│
├── scripts/ # Utility scripts
│ ├── remote_setup.sh # Remote server setup
│ └── verify_deterministic_masking.py
│
└── requirements.txt # Python dependencies
```

Architecture Details:

V1 (Baseline):

Input Coordinates → Fourier Features → MAMBA (6 layers) → Cross-Attention → Decoder → Output
- MAMBA: 6 unidirectional layers (left → right)
- Cross-Attention: Single layer for input-query interaction
- d_model: 256 (default)
- Parameters: ~4M
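
The pipeline's first stage lifts 2D pixel coordinates into a high-frequency embedding before the MAMBA encoder. Below is a minimal sketch of Gaussian Fourier features in the style of Tancik et al. (2020); the class name, default sizes, and `scale` hyperparameter are illustrative assumptions, not the repo's actual API:

```python
import torch

class FourierFeatures(torch.nn.Module):
    """Random Fourier feature encoding for 2D coordinates (illustrative sketch)."""

    def __init__(self, in_dim: int = 2, num_features: int = 128, scale: float = 10.0):
        super().__init__()
        # Fixed random projection B ~ N(0, scale^2); a buffer, so it is saved
        # with the model but never trained.
        self.register_buffer("B", torch.randn(in_dim, num_features) * scale)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        # coords: (..., 2) in [0, 1]  ->  features: (..., 2 * num_features)
        proj = 2 * torch.pi * coords @ self.B
        return torch.cat([proj.sin(), proj.cos()], dim=-1)

# 1024 query coordinates -> 256-dim features, matching d_model=256 above.
feats = FourierFeatures()(torch.rand(1, 1024, 2))
```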

V2 (Bidirectional + Perceiver):

Input Coordinates → Fourier Features → Bidirectional MAMBA (8 layers) → Lightweight Perceiver → Decoder → Output
- Bidirectional MAMBA: 4 forward + 4 backward = 8 layers
- Lightweight Perceiver: 2 iterations with query self-attention
- d_model: 256 (default)
- Parameters: ~5M
Key V2 Improvements:
- Bidirectional Context: Every pixel sees information from both directions
- Query Self-Attention: Neighboring query pixels communicate for spatial coherence
- Iterative Refinement: 2-iteration perceiver for coarse-to-fine processing
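
To make the bidirectional scan concrete, here is a minimal sketch of running two unidirectional stacks over the sequence and its reversal; `forward_layers` and `backward_layers` stand in for stacks of Mamba blocks, and the additive merge is an assumption for illustration:

```python
import torch

def bidirectional_encode(x, forward_layers, backward_layers):
    # x: (batch, seq_len, d_model)
    fwd = x
    for layer in forward_layers:        # left-to-right context
        fwd = layer(fwd)
    bwd = torch.flip(x, dims=[1])       # reverse along the sequence axis
    for layer in backward_layers:       # right-to-left context
        bwd = layer(bwd)
    bwd = torch.flip(bwd, dims=[1])     # restore the original ordering
    return fwd + bwd                    # every position sees both directions
```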

V3 (Morton Curves):

Input Coordinates → Fourier Features → Morton Reorder → MAMBA (6 layers) → Restore Order → Cross-Attention → Decoder → Output
- MAMBA: 6 unidirectional layers (same as V1)
- Morton Curves: Z-order sequencing for better spatial locality
- d_model: 256 (same as V1)
- Parameters: ~4M (identical to V1)
Key V3 Improvements:
- Better Spatial Locality: Neighbors in 2D are also neighbors in 1D sequence
- Zero Extra Cost: Same computational complexity as V1
- Clean Improvement: Only sequence ordering changes, architecture unchanged
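
For intuition, a minimal sketch of Morton (Z-order) indexing: interleaving the bits of (row, col) yields a 1D order in which 2D neighbors tend to stay close. Names and the 32×32 grid below are illustrative, not the repo's implementation:

```python
def morton_index(row: int, col: int, bits: int = 16) -> int:
    """Interleave the bits of (row, col) into a single Z-order code."""
    code = 0
    for i in range(bits):
        code |= ((col >> i) & 1) << (2 * i)      # column bits at even positions
        code |= ((row >> i) & 1) << (2 * i + 1)  # row bits at odd positions
    return code

# Reorder a flattened 32x32 image: sort row-major pixel indices by Morton code.
H = W = 32
order = sorted(range(H * W), key=lambda p: morton_index(p // W, p % W))
# `order[k]` is the row-major index of the k-th pixel in Morton order; applying
# the inverse permutation after the MAMBA layers restores row-major order.
```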

V4 (Transformer Comparison):

Input Coordinates → Fourier Features → Positional Encoding → Transformer (6 layers) → Cross-Attention → Decoder → Output
- Transformer: 6 layers with multi-head self-attention
- Positional Encoding: Sinusoidal position embeddings
- d_model: 256 (same as V1)
- Parameters: ~4M (similar to V1)
- Complexity: O(N²) quadratic attention
Key V4 Purpose:
- Benchmark MAMBA: Compare linear vs quadratic complexity in practice
- Global vs Sequential: Test if full attention helps sparse neural fields
- Fair Comparison: Same depth and dimension as V1
- Research Baseline: Standard architecture for comparison
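
V4's positional encoding is the standard sinusoidal scheme from Vaswani et al. (2017); a minimal sketch (function name illustrative):

```python
import torch

def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (seq_len, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * (-torch.log(torch.tensor(10000.0)) / d_model)
    )                                                                    # (d_model/2,)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe  # added to the token embeddings before the self-attention layers
```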

Architecture Comparison:

| Feature | V1 | V2 | V3 | V4 |
|---|---|---|---|---|
| Encoder | MAMBA | Bidirectional MAMBA | MAMBA | Transformer |
| Attention | 1 cross-attn | Perceiver + self-attn | 1 cross-attn | Self + Cross |
| Ordering | Row-major | Row-major | Morton curve | Row-major |
| Complexity | O(N) | O(N) | O(N) | O(N²) |
| d_model | 256 | 256 | 256 | 256 |
| Layers | 6 | 8 | 6 | 6 |
| Parameters | 4M | 5M | 4M | 4M |
| Compute Cost | 1.0x | 1.7x | 1.0x | 10-20x |
| Philosophy | Baseline | Architectural | Ordering | Comparison |
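
The compute-cost row reflects self-attention's quadratic scaling in sequence length. A rough, illustrative way to observe this effect, with `nn.TransformerEncoderLayer` standing in for V4 and a GRU (another linear-time recurrence) standing in for a MAMBA block, purely as an assumption for the sketch:

```python
import time

import torch

d_model = 256
attn_layer = torch.nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
linear_layer = torch.nn.GRU(d_model, d_model, batch_first=True)  # linear-time stand-in

with torch.no_grad():
    for n in (256, 1024, 4096):
        x = torch.randn(1, n, d_model)
        t0 = time.perf_counter(); attn_layer(x); t_attn = time.perf_counter() - t0
        t0 = time.perf_counter(); linear_layer(x); t_lin = time.perf_counter() - t0
        # Attention time grows ~quadratically with n; recurrence grows ~linearly.
        print(f"N={n:5d}  self-attention {t_attn:.3f}s  recurrence {t_lin:.3f}s")
```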
Training Configuration:

V1:

```bash
d_model=256
num_layers=6
batch_size=64
lr=1e-4
epochs=1000
```

V2:

```bash
d_model=256
num_layers=8            # 4 forward + 4 backward
batch_size=64
lr=1e-4
epochs=1000
perceiver_iterations=2
perceiver_heads=8
```

V3:

```bash
d_model=256             # Same as V1
num_layers=6            # Same as V1
batch_size=64
lr=1e-4
epochs=1000
morton_ordering=True    # NEW: Enabled by default
```

V4:

```bash
d_model=256             # Same as V1
num_layers=6            # Same as V1
num_heads=8             # Multi-head attention
dim_feedforward=1024    # 4x expansion ratio
batch_size=64
lr=1e-4
epochs=1000
```

Dataset:
- CIFAR-10: 32×32 RGB images
- Sparse Sampling: 20% of pixels randomly selected (deterministic per image)
- Train/Val Split: Standard CIFAR-10 split
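
A minimal sketch of how deterministic per-image masking can work, assuming the mask is derived by seeding a generator with the image index; names and the top-k selection are illustrative (the repo's scripts/verify_deterministic_masking.py suggests this reproducibility is checked explicitly):

```python
import torch

def sparse_mask(image_index: int, height: int = 32, width: int = 32,
                keep_ratio: float = 0.2) -> torch.Tensor:
    gen = torch.Generator().manual_seed(image_index)  # same image -> same mask, every epoch
    scores = torch.rand(height * width, generator=gen)
    k = int(keep_ratio * height * width)              # 20% of 32*32 = 204 pixels
    mask = torch.zeros(height * width, dtype=torch.bool)
    mask[scores.topk(k).indices] = True               # keep the k highest scores
    return mask.view(height, width)                   # True = observed pixel
```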
Three sampling methods are supported:

1. Heun ODE Solver (default):
   - Second-order accuracy
   - Deterministic
   - Good baseline quality

2. SDE Sampling:
   - Adds Langevin dynamics
   - Stochastic exploration
   - Temperature-controlled noise

3. DDIM Sampling:
   - Non-uniform timestep schedule
   - Faster convergence
   - Configurable stochasticity (eta)
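
A minimal sketch of the Heun (second-order) step for integrating the learned velocity field from noise at t=0 to data at t=1; `velocity` stands in for the trained model, and the time convention is an assumption:

```python
import torch

@torch.no_grad()
def heun_sample(velocity, x, num_steps: int = 50):
    # x starts as pure noise, e.g. torch.randn(batch, num_pixels, 3)
    ts = torch.linspace(0.0, 1.0, num_steps + 1)
    for i in range(num_steps):
        t, t_next = ts[i], ts[i + 1]
        dt = t_next - t
        v1 = velocity(x, t)                  # slope at the current point (Euler predictor)
        v2 = velocity(x + dt * v1, t_next)   # slope at the predicted next point
        x = x + dt * 0.5 * (v1 + v2)         # average the two slopes (Heun corrector)
    return x
```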

Evaluation Metrics:
- PSNR (Peak Signal-to-Noise Ratio): Measures reconstruction quality
- SSIM (Structural Similarity Index): Perceptual similarity metric
- MSE (Mean Squared Error): Pixel-wise error
- MAE (Mean Absolute Error): Average absolute difference
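
PSNR follows directly from MSE; a minimal sketch assuming images scaled to [0, 1]:

```python
import torch

def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    mse = torch.mean((pred - target) ** 2)         # pixel-wise squared error
    return 10.0 * torch.log10(max_val ** 2 / mse)  # in dB; higher is better
```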
Custom training settings:

```bash
# V1 with custom settings
cd v1/training
D_MODEL=512 NUM_LAYERS=6 BATCH_SIZE=32 LR=5e-5 ./run_mamba_training.sh

# V2 with custom settings
cd v2/training
D_MODEL=256 NUM_LAYERS=8 BATCH_SIZE=32 LR=5e-5 ./run_mamba_v2_training.sh
```

Monitoring training:

```bash
# V1
cd v1/training
./monitor_training.sh

# V2
cd v2/training
tail -f training_v2_output.log
```

Stopping training:

```bash
# V1
cd v1/training
./stop_mamba_training.sh

# V2
cd v2/training
kill $(cat training_v2.pid)
```

See the docs/ directory for detailed documentation:
- README_V2.md: Comprehensive V2 architecture guide with design decisions
- README_V3.md: V3 Morton curves implementation and spatial locality
- README_V4.md: V4 Transformer comparison and complexity analysis
- README_SUPERRES.md: Super-resolution evaluation guide
- README_SDE.md: SDE and DDIM sampling methods
- QUICKSTART_EVAL.md: Quick evaluation reference
- QUICKSTART_SDE.md: Quick SDE sampling reference
- TRAINING_README.md: Detailed training guide
Troubleshooting:

Out of memory:

```bash
# Reduce batch size or model dimension
D_MODEL=128 BATCH_SIZE=32 ./run_mamba_training.sh
```

Data loading issues:

```bash
# Reduce number of workers or model size
NUM_WORKERS=2 D_MODEL=256 ./run_mamba_training.sh
```

Speckled or noisy samples:
- This is expected with the V1 baseline! Try V3 (Morton curves) for better spatial coherence with zero extra cost
- Or use the V2 architecture for cleaner results through bidirectional processing

Slow V4 training:
- V4 is 10-20x slower than V1 due to quadratic O(N²) attention
- This is expected and by design, for comparison purposes
- Reduce the batch size, or use V1/V3 for faster training
Contributions welcome! Please open issues or pull requests.
MIT License - see LICENSE file for details.

References:
- MAMBA: Gu & Dao (2023) - Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- Transformer: Vaswani et al. (2017) - Attention Is All You Need
- Flow Matching: Lipman et al. (2023) - Flow Matching for Generative Modeling
- Perceiver: Jaegle et al. (2021) - Perceiver: General Perception with Iterative Attention
- Neural Fields: Tancik et al. (2020) - Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
- Morton Curves: Morton, G.M. (1966) - A Computer Oriented Geodetic Data Base; and a New Technique in File Sequencing
For questions or issues, please open a GitHub issue.
Made with ❤️ for sparse neural field generation