[FEATURE] Collate Function for Batching Variable-Length Lag Sets #11

@GongJr0

Description

Feature Details

Implement a collate_fn that takes a list of per-sample sparse-lag items and returns padded, mask-aware tensors suitable for the model. Each sample contains selected lag features and categorical IDs; the number of selected lags, $K_i$, varies from sample to sample. The collate function must:

  • Pad every sample to $K_{max}$, the largest $K_i$ in the batch.
  • Produce a boolean pad mask (True = pad).
  • Keep dtypes consistent (values float, IDs int64, masks bool).
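
The pad-mask requirement can be illustrated with a minimal sketch. The helper name `make_pad_mask` is hypothetical and not part of this issue; it assumes the per-sample lengths $K_i$ are already available as an int64 tensor.

```python
from typing import Optional

import torch


def make_pad_mask(lengths: torch.Tensor, k_max: Optional[int] = None) -> torch.Tensor:
    """Return a (B, K_max) bool mask where True marks a padded position.

    `make_pad_mask` is a hypothetical helper name used for illustration only.
    """
    if k_max is None:
        # K_max is the longest sequence in the batch (0 for an empty batch).
        k_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
    positions = torch.arange(k_max, device=lengths.device)  # (K_max,)
    # Position j of row i is padding when j >= K_i.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)    # (B, K_max)


lengths = torch.tensor([3, 1, 0])
mask = make_pad_mask(lengths)
```

With lengths `[3, 1, 0]` the first row has no padding, the second row pads its last two positions, and the third row ($K_i=0$) is all padding.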

Input (per sample)

{
  "vals":     np.ndarray | torch.Tensor,  # shape: (K_i, F_val)
  "lag_ids":  np.ndarray | torch.Tensor,  # shape: (K_i,), int
  "ticker_id": int,                        # scalar
  # optional
  "sector_id": int,
  "meta": {...}                            # optional passthrough
}
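
One way to normalise a single sample before padding is sketched below. The function name `to_sample_tensors` and the float32 default are assumptions, not part of this issue. `torch.as_tensor` is used because it returns the input unchanged when it is already a tensor of the requested dtype and device, and shares memory with a NumPy array of matching dtype.

```python
import numpy as np
import torch


def to_sample_tensors(sample, val_dtype=torch.float32):
    """Convert one sample's `vals` and `lag_ids` to tensors and check shapes.

    Hypothetical helper; `val_dtype=torch.float32` is an assumed default.
    """
    # as_tensor avoids a copy when dtype (and device) already match.
    vals = torch.as_tensor(sample["vals"], dtype=val_dtype)
    lag_ids = torch.as_tensor(sample["lag_ids"], dtype=torch.int64)
    if vals.ndim != 2:
        raise ValueError(f"vals must be (K_i, F_val), got shape {tuple(vals.shape)}")
    if lag_ids.ndim != 1 or lag_ids.shape[0] != vals.shape[0]:
        raise ValueError("lag_ids must be (K_i,) with the same K_i as vals")
    return vals, lag_ids


# A NumPy-backed sample with K_i = 3 and F_val = 2.
np_sample = {
    "vals": np.zeros((3, 2), dtype=np.float32),
    "lag_ids": np.array([1, 5, 20]),
    "ticker_id": 7,
}
np_vals, np_ids = to_sample_tensors(np_sample)

# A tensor-backed sample takes the no-copy path.
t = torch.ones(2, 4)
t_vals, t_ids = to_sample_tensors({"vals": t, "lag_ids": torch.tensor([1, 2]), "ticker_id": 3})
```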

Output (batched)

vals:       torch.FloatTensor   # (B, K_max, F_val)
lag_ids:    torch.LongTensor    # (B, K_max)
ticker_ids: torch.LongTensor    # (B,)
sector_ids: torch.LongTensor | None   # (B,), or None when absent
pad_mask:   torch.BoolTensor    # (B, K_max), True == PAD
lengths:    torch.LongTensor    # (B,), valid K_i per row
meta:       list[dict] | None   # optional passthrough
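
The input and output contracts above could be realised roughly as follows. This is a minimal sketch, not the final implementation: it omits `sort_by_len` and `pin_memory`, fixes values to float32, and makes two assumptions the issue does not settle, namely that `sector_ids` is returned only when every sample carries `sector_id`, and that a missing `meta` entry is passed through as an empty dict.

```python
from typing import Any, Dict, List, NamedTuple, Optional

import torch


class CollatedBatch(NamedTuple):
    vals: torch.Tensor                   # (B, K_max, F_val), float32
    lag_ids: torch.Tensor                # (B, K_max), int64
    ticker_ids: torch.Tensor             # (B,), int64
    sector_ids: Optional[torch.Tensor]   # (B,), int64, or None
    pad_mask: torch.Tensor               # (B, K_max), bool, True == PAD
    lengths: torch.Tensor                # (B,), int64
    meta: Optional[List[Dict[str, Any]]]


def collate_variable_lags(samples, *, pad_value=0.0, pad_idx=0):
    if not samples:
        raise ValueError("cannot collate an empty list of samples")
    vals_list = [torch.as_tensor(s["vals"], dtype=torch.float32) for s in samples]
    ids_list = [torch.as_tensor(s["lag_ids"], dtype=torch.int64) for s in samples]
    lengths = torch.tensor([v.shape[0] for v in vals_list], dtype=torch.int64)
    batch, k_max, f_val = len(samples), int(lengths.max()), vals_list[0].shape[1]

    # Pre-fill with the pad values, then copy each sample's valid prefix.
    vals = torch.full((batch, k_max, f_val), pad_value, dtype=torch.float32)
    lag_ids = torch.full((batch, k_max), pad_idx, dtype=torch.int64)
    for row, (v, ids) in enumerate(zip(vals_list, ids_list)):
        k = v.shape[0]
        vals[row, :k] = v
        lag_ids[row, :k] = ids

    # True wherever the position index is at or beyond the row's length.
    pad_mask = torch.arange(k_max).unsqueeze(0) >= lengths.unsqueeze(1)
    ticker_ids = torch.tensor([s["ticker_id"] for s in samples], dtype=torch.int64)

    # Assumption: sector_ids only when every sample provides sector_id.
    sector_ids = None
    if all("sector_id" in s for s in samples):
        sector_ids = torch.tensor([s["sector_id"] for s in samples], dtype=torch.int64)
    # Assumption: samples without meta contribute an empty dict.
    meta = [s.get("meta", {}) for s in samples] if any("meta" in s for s in samples) else None

    return CollatedBatch(vals, lag_ids, ticker_ids, sector_ids, pad_mask, lengths, meta)


samples = [
    {"vals": torch.ones(3, 2), "lag_ids": torch.tensor([1, 5, 20]), "ticker_id": 7},
    {"vals": torch.ones(1, 2), "lag_ids": torch.tensor([2]), "ticker_id": 9},
]
batch = collate_variable_lags(samples)
```

Because `pad_idx=0` can coincide with a real lag ID, downstream code should rely on `pad_mask` or `lengths` rather than on the padded ID value to locate padding.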

Affected Modules

As stated in the parent issue.

Implementation Checklist

  • CollatedBatch typed container (NamedTuple / dataclass) for outputs.
  • Implement collate_variable_lags(samples, *, pad_value=0.0, pad_idx=0, sort_by_len=False, pin_memory=False)
  • Fast-path for inputs already in torch.Tensor (avoid extra copies).
  • Add a pinned-memory path (pin_memory=True).
  • Unit tests:
    • Heterogeneous $K_i$ → correct padding & masks.
    • Dtype checks: vals==float32/64, IDs==int64, mask==bool.
    • Optional fields present/absent.
    • Sorting on/off preserves content; if sorted, also return restore_idx (the permutation that restores the original sample order).
    • Edge cases: $K_i=0$ (all pad), $K_i=K_{max}$, $B=1$.
    • Perf sanity ($B=256$, $K_{max}=64$).
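
The `sort_by_len` and `pin_memory` items in the checklist could be approached as sketched below. The helper names `sort_by_length` and `maybe_pin` are hypothetical, sorting in descending order is an assumption, and skipping pinning when CUDA is unavailable is one possible design choice rather than a requirement of this issue.

```python
import torch


def sort_by_length(lengths: torch.Tensor, descending: bool = True):
    """Return (sort_idx, restore_idx) for a (B,) tensor of lengths.

    Indexing a batch with sort_idx orders rows by length; indexing the
    sorted batch with restore_idx recovers the original row order.
    """
    # A stable sort keeps equal-length rows in their original relative order.
    _, sort_idx = torch.sort(lengths, descending=descending, stable=True)
    restore_idx = torch.empty_like(sort_idx)
    restore_idx[sort_idx] = torch.arange(sort_idx.numel())
    return sort_idx, restore_idx


def maybe_pin(tensor: torch.Tensor, pin_memory: bool) -> torch.Tensor:
    """Pin host memory only when requested and CUDA is available."""
    if pin_memory and torch.cuda.is_available():
        return tensor.pin_memory()
    return tensor


lengths = torch.tensor([1, 3, 2])
sort_idx, restore_idx = sort_by_length(lengths)
sorted_lengths = lengths[sort_idx]
round_trip = sorted_lengths[restore_idx]
pinned = maybe_pin(torch.zeros(4, 2), pin_memory=True)
```

A round-trip check such as `sorted_lengths[restore_idx] == lengths` is the kind of assertion the "sorting preserves content" unit test could apply to every field of the batch.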

Limitations

As stated in the parent issue.

Metadata

Assignees

Labels

feature: Implementation tracking for approved features

Projects

Status

In progress

Milestone

No milestone
