Skip to content

Unified meta-reinforcement learning benchmark for fast adaptation with State Space Models (SSM), test-time improvement, and modular policy orchestration. Includes automated training, evaluation, adaptation, and CI/CD pipelines.

License

Notifications You must be signed in to change notification settings

sunghunkwag/SSM-MetaRL-TestCompute

Repository files navigation

SSM-MetaRL-TestCompute

A research framework combining neural state modeling (inspired by State Space Models), Meta-Learning (MAML), and Test-Time Adaptation for reinforcement learning.

Tests Python License Docker Open In Colab

⚠️ Important Clarification

This framework uses neural networks with explicit state representation for temporal modeling, inspired by State Space Model concepts. It is NOT a structured SSM like S4/Mamba/LRU.

What We Have ✅

  • Explicit state representation with recurrent processing
  • Residual state transitions using MLPs
  • Meta-learning via MAML for fast adaptation
  • Test-time adaptation with gradient descent
  • Compatible with standard RL environments

What We Don't Have ❌

  • Structured SSM parameterization (HiPPO, diagonal matrices, low-rank, etc.)
  • Continuous-time dynamics with discretization schemes
  • FFT-based convolution mode for parallel processing
  • Sub-quadratic complexity (actual: O(T·d²) per sequence, similar to RNNs)

Why This Design?

  • Simpler to implement and debug than structured SSMs
  • Compatible with existing meta-learning algorithms (MAML, etc.)
  • Sufficient for proof-of-concept experiments and RL tasks
  • Easier to extend with custom architectures

Complexity Analysis

Component This Framework Structured SSM Traditional RNN
Forward pass O(T·d²) O(T·d) or O(T log T) O(T·d²)
Parameters ~50K ~50K ~50K
Parallelizable ❌ (recurrent) ✅ (convolution mode) ❌ (recurrent)

Honest assessment: Our implementation has similar complexity to GRU/LSTM, not better. The main contribution is the meta-learning + adaptation framework, not the state modeling itself.


Features

  • Neural State Modeling for temporal dynamics (explicit state with residual transitions)
  • Meta-Learning (MAML) for fast adaptation across tasks
  • Test-Time Adaptation for online model improvement
  • Modular Architecture with clean, testable components
  • Gymnasium Integration for RL environment compatibility
  • Test Suite with automated CI/CD
  • Docker Container ready for deployment
  • Baseline Comparisons with LSTM, GRU, Transformer implementations

Project Structure

  • core/: Core model implementations
    • ssm.py: Neural state model (MLP-based state transitions, returns state)
  • meta_rl/: Meta-learning algorithms
    • meta_maml.py: MetaMAML implementation (handles stateful models and time series input)
  • adaptation/: Test-time adaptation
    • test_time_adaptation.py: Adapter class (manages hidden state updates internally)
  • env_runner/: Environment utilities
    • environment.py: Gymnasium environment wrapper
  • experiments/: Experiment scripts and benchmarks
    • quick_benchmark.py: Quick benchmark suite
    • serious_benchmark.py: High-dimensional MuJoCo benchmarks (work in progress)
    • task_distributions.py: Meta-learning task distributions
    • baselines.py: LSTM, GRU, Transformer baseline implementations
  • tests/: Test suite for all components

Interactive Demo

Try it now: Open In Colab

Run the complete demo in your browser with Google Colab - no installation required!

Demo Notebook Features

  • Correct API Usage: Demonstrates proper MetaMAML and Adapter APIs
  • Time Series Handling: Proper 3D tensor shapes (batch, time, features)
  • Hidden State Management: Correct initialization and propagation
  • Visualization: Loss curves and adaptation progress
  • Evaluation: Model performance metrics

Advanced Benchmarks (Work in Progress)

We're implementing benchmarks on high-dimensional MuJoCo tasks with baseline comparisons.

Motivation

Simple benchmarks (CartPole, Pendulum) have limitations for research validation:

  • Low dimensional (4-8 state dims)
  • Simple dynamics
  • Limited baseline comparisons

Planned Experiments

High-Dimensional Tasks

  • HalfCheetah-v4: 17-dim state, 6-dim action
  • Ant-v4: 27-dim state, 8-dim action
  • Humanoid-v4: 376-dim state, 17-dim action

Baseline Comparisons

  • LSTM-MAML (76K params, O(n²) complexity)
  • GRU-MAML (57K params, O(n²) complexity)
  • Transformer-MAML (400K params, O(n²) complexity)
  • MLP-MAML (20K params, no sequence modeling)
  • This Framework (53K params, O(n²) complexity, with meta-learning)

Note: Full results coming soon. Current focus is on framework correctness and API stability.


Quick Start

Installation

git clone https://github.com/sunghunkwag/SSM-MetaRL-TestCompute.git
cd SSM-MetaRL-TestCompute
pip install -e .

# For development:
pip install -e .[dev]

Docker Installation

# Pull the latest container
docker pull ghcr.io/sunghunkwag/ssm-metarl-testcompute:latest

# Run main script
docker run --rm ghcr.io/sunghunkwag/ssm-metarl-testcompute:latest python main.py --env_name CartPole-v1

# Run benchmark
docker run --rm ghcr.io/sunghunkwag/ssm-metarl-testcompute:latest python experiments/quick_benchmark.py

Run Main Script

# Train on CartPole environment
python main.py --env_name CartPole-v1 --num_epochs 20

# Train on Pendulum environment
python main.py --env_name Pendulum-v1 --num_epochs 10

Run Benchmark

python experiments/quick_benchmark.py

Run Tests

pytest

Test Results

The framework has been tested with the following results:

Test Category Status Pass Rate
Unit Tests ✅ All Passing 100%
CI/CD Pipeline ✅ Automated Python 3.8-3.11
CartPole-v1 ✅ Passed Loss reduction: 91.5% - 93.7%
Pendulum-v1 ✅ Passed Loss reduction: 95.9%
Benchmarks ✅ Passed Loss reduction: 86.8%

Verified Functionality

  • ✅ Neural State Model - All features working
  • ✅ MetaMAML - Meta-learning operational
  • ✅ Test-Time Adaptation - Adaptation effects confirmed
  • ✅ Environment Runner - Multiple environments supported
  • ✅ Docker Container - Automated builds and deployment

Core Components

State Space Model (core/ssm.py)

The model in core/ssm.py implements neural state transitions with explicit state tracking.

Architecture: MLP-based state transitions with residual connections:

h_t = MLP(h_{t-1}) + Linear(x_t)
y_t = MLP(h_t) + Linear(x_t)

API:

  • forward(x, hidden_state) returns tuple: (output, next_hidden_state)
  • init_hidden(batch_size) provides initial hidden state

Constructor Arguments:

  • state_dim (int): Internal state dimension
  • input_dim (int): Input feature dimension
  • output_dim (int): Output feature dimension
  • hidden_dim (int): Hidden layer dimension within MLPs
  • device (str or torch.device)

Example usage:

import torch
from core.ssm import StateSpaceModel

model = StateSpaceModel(state_dim=128, input_dim=64, output_dim=32, device='cpu')
batch_size = 4
input_x = torch.randn(batch_size, 64)
current_hidden = model.init_hidden(batch_size)

# Forward pass requires current state and returns next state
output, next_hidden = model(input_x, current_hidden)
print(output.shape)       # torch.Size([4, 32])
print(next_hidden.shape)  # torch.Size([4, 128])

MetaMAML

The MetaMAML class in meta_rl/meta_maml.py implements MAML.

Key Features:

  • Handles stateful models (with explicit state passing)
  • Supports time series input (B, T, D)
  • API: meta_update takes tasks (list of tuples) and initial_hidden_state

Time Series Input Handling: Input data should be shaped (batch_size, time_steps, features). MAML processes sequences internally.

Example with time series:

import torch
import torch.nn.functional as F
from meta_rl.meta_maml import MetaMAML
from core.ssm import StateSpaceModel

model = StateSpaceModel(state_dim=64, input_dim=32, output_dim=16, device='cpu')
maml = MetaMAML(model, inner_lr=0.01, outer_lr=0.001)

# Time series input: (batch=4, time_steps=10, features=32)
support_x = torch.randn(4, 10, 32)
support_y = torch.randn(4, 10, 16)
query_x = torch.randn(4, 10, 32)
query_y = torch.randn(4, 10, 16)

# Prepare tasks as list of tuples
tasks = []
for i in range(4):
    tasks.append((support_x[i:i+1], support_y[i:i+1], query_x[i:i+1], query_y[i:i+1]))

# Initialize hidden state
initial_hidden = model.init_hidden(batch_size=4)

# Call meta_update with tasks list and initial state
loss = maml.meta_update(tasks=tasks, initial_hidden_state=initial_hidden, loss_fn=F.mse_loss)
print(f"Meta Loss: {loss:.4f}")

Constructor Arguments:

  • model: The base model
  • inner_lr (float): Inner loop learning rate
  • outer_lr (float): Outer loop learning rate
  • first_order (bool): Use first-order MAML

Adapter (Test-Time Adaptation)

The Adapter class in adaptation/test_time_adaptation.py performs test-time adaptation.

Key Features:

  • API: update_step takes x, y (target), and hidden_state
  • Internally performs config.num_steps gradient updates per call
  • Properly detaches hidden state to prevent autograd errors
  • Returns (loss, steps_taken)

Constructor Arguments:

  • model: The model to adapt
  • config: An AdaptationConfig object with learning_rate and num_steps
  • device: Device string ('cpu' or 'cuda')

Example usage:

import torch
from adaptation.test_time_adaptation import Adapter, AdaptationConfig
from core.ssm import StateSpaceModel

model = StateSpaceModel(state_dim=64, input_dim=32, output_dim=32, device='cpu')
config = AdaptationConfig(learning_rate=0.01, num_steps=5)
adapter = Adapter(model=model, config=config, device='cpu')

# Initialize hidden state
hidden_state = model.init_hidden(batch_size=1)

# Adaptation loop
for step in range(10):
    x = torch.randn(1, 32)
    y_target = torch.randn(1, 32)
    
    # Store current state
    current_hidden_state = hidden_state
    
    # Get next state prediction
    with torch.no_grad():
        output, hidden_state = model(x, current_hidden_state)
    
    # Adapt model
    loss, steps_taken = adapter.update_step(
        x=x,
        y=y_target,
        hidden_state=current_hidden_state
    )
    print(f"Step {step}, Loss: {loss:.4f}")

Docker Usage

Uses multi-stage build for efficient containerization with automated CI/CD.

Pull Pre-built Container:

# Latest version
docker pull ghcr.io/sunghunkwag/ssm-metarl-testcompute:latest

# Specific version
docker pull ghcr.io/sunghunkwag/ssm-metarl-testcompute:main

Build Locally:

docker build -t ssm-metarl .

Run:

# Run main script
docker run --rm ghcr.io/sunghunkwag/ssm-metarl-testcompute:latest python main.py --env_name Pendulum-v1

# Run tests
docker run --rm ghcr.io/sunghunkwag/ssm-metarl-testcompute:latest pytest

Recent Updates (v1.3.1)

Honesty Update (Latest)

  1. README Clarification (Commit: TBD)
    • Added: ⚠️ Important Clarification section
    • Clarified: This is neural network based, NOT structured SSM
    • Updated: Complexity claims to be accurate (O(T·d²))
    • Explained: What we have vs what we don't have
    • Impact: Honest representation of the framework

Demo Notebook Fixes (v1.3.0)

  1. MetaMAML API Correction

    • Fixed: Corrected meta_update() to use tasks list and initial_hidden_state
    • Impact: Demo now runs without errors in Colab
  2. Adapter API Correction

    • Fixed: Replaced non-existent adapt() with update_step()
    • Impact: Proper adaptation with loss tracking
  3. Data Shape Fixes

    • Fixed: Proper 3D tensor reshaping (batch, time, features)
    • Impact: MetaMAML processes sequences correctly
  4. Hidden State Management

    • Fixed: Proper initialization and propagation
    • Impact: State tracking works correctly

Requirements

  • Python >= 3.8
  • PyTorch >= 2.0
  • Gymnasium >= 1.0
  • NumPy
  • pytest (for development)

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this framework in your research, please cite:

@software{ssm_metarl_testcompute,
  author = {Sung Hun Kwag},
  title = {SSM-MetaRL-TestCompute: A Framework for Meta-RL with Neural State Modeling},
  year = {2025},
  url = {https://github.com/sunghunkwag/SSM-MetaRL-TestCompute}
}

Acknowledgments

This framework builds upon research in:

  • State Space Models (conceptual inspiration)
  • Model-Agnostic Meta-Learning (MAML)
  • Test-time adaptation techniques
  • Reinforcement learning with Gymnasium

About

Unified meta-reinforcement learning benchmark for fast adaptation with State Space Models (SSM), test-time improvement, and modular policy orchestration. Includes automated training, evaluation, adaptation, and CI/CD pipelines.

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages