Feature Details
Implement a collate_fn that takes a list of per-sample sparse-lag items and returns padded, mask-aware tensors suitable for the model. Each sample contains selected lag features and categorical IDs; sequences have variable length $K_i$. The collate function must:

- Pad to $K_{max}$ within the batch,
- Produce a boolean pad mask (True = pad; illustrated just below),
- Keep dtypes consistent (values float, IDs int64, masks bool).
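For reference, the mask convention (True == PAD) can be derived from per-row lengths by broadcasting; the snippet below is only an illustration with made-up lengths, not part of the required API:

```python
import torch

# Illustrative only: derive a (B, K_max) pad mask (True == PAD) from per-row lengths.
lengths = torch.tensor([3, 1])                               # valid K_i per row
k_max = int(lengths.max())                                   # K_max within the batch
pad_mask = torch.arange(k_max)[None, :] >= lengths[:, None]  # (B, K_max), bool
# tensor([[False, False, False],
#         [False,  True,  True]])
```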
Input (per sample)
```python
{
    "vals": np.ndarray | torch.Tensor,     # shape: (K_i, F_val)
    "lag_ids": np.ndarray | torch.Tensor,  # shape: (K_i,), int
    "ticker_id": int,                      # scalar
    # optional
    "sector_id": int,
    "meta": {...},                         # optional passthrough
}
```
Output (batched)
```python
vals:       torch.FloatTensor          # (B, K_max, F_val)
lag_ids:    torch.LongTensor           # (B, K_max)
ticker_ids: torch.LongTensor           # (B,)
sector_ids: torch.LongTensor | None
pad_mask:   torch.BoolTensor           # (B, K_max), True == PAD
lengths:    torch.LongTensor           # (B,), valid K_i per row
meta:       list[dict] | None          # optional passthrough
```
Affected Modules
As stated in the parent issue.
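To pin down the batched output above, one possible shape for the `CollatedBatch` container listed in the checklist below is sketched here as a dataclass. The field names and dtypes mirror the output spec; the concrete container type (NamedTuple vs. dataclass) and defaults are still open choices:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CollatedBatch:
    """Padded, mask-aware batch produced by the variable-lag collate function."""
    vals: torch.Tensor                          # (B, K_max, F_val), float32/64
    lag_ids: torch.Tensor                       # (B, K_max), int64
    ticker_ids: torch.Tensor                    # (B,), int64
    pad_mask: torch.Tensor                      # (B, K_max), bool, True == PAD
    lengths: torch.Tensor                       # (B,), int64, valid K_i per row
    sector_ids: Optional[torch.Tensor] = None   # (B,), int64, if provided
    meta: Optional[list] = None                 # per-sample passthrough dicts
```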
Implementation Checklist
- [ ] `CollatedBatch` typed container (NamedTuple/dataclass) for outputs.
- [ ] Implement `collate_variable_lags(samples, *, pad_value=0.0, pad_idx=0, sort_by_len=False, pin_memory=False)` (a sketch follows this checklist).
- [ ] Fast path for inputs already in `torch.Tensor` (avoid extra copies).
- [ ] Add pinned-memory path.
- [ ] Unit tests:
  - Heterogeneous $K_i$ → correct padding & masks.
  - Dtype checks: `vals` == float32/64, IDs == int64, mask == bool.
  - Optional fields present/absent.
  - Sorting on/off preserves content; if sorted, return `restore_idx`.
  - Edge cases: $K_i = 0$ (all pad), $K_i = K_{max}$, $B = 1$.
  - Perf sanity ($B = 256$, $K_{max} = 64$).
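A minimal sketch of `collate_variable_lags`, reusing the `CollatedBatch` dataclass sketched above and the signature from the checklist. Everything beyond the stated contract (casting values to float32, returning `restore_idx` only when `sort_by_len=True`, pinning every tensor field) is an assumption, not a settled design:

```python
import torch

# CollatedBatch: see the dataclass sketch above.


def collate_variable_lags(samples, *, pad_value=0.0, pad_idx=0,
                          sort_by_len=False, pin_memory=False):
    """Pad variable-length sparse-lag samples into mask-aware tensors (sketch, assumes B >= 1)."""
    lengths = torch.as_tensor(
        [int(torch.as_tensor(s["lag_ids"]).numel()) for s in samples], dtype=torch.int64
    )

    # Optional length sort (descending); restore_idx maps sorted rows back to input order.
    order = torch.argsort(lengths, descending=True) if sort_by_len else torch.arange(len(samples))
    restore_idx = torch.argsort(order)
    samples = [samples[i] for i in order.tolist()]
    lengths = lengths[order]

    b = len(samples)
    k_max = int(lengths.max())
    f_val = int(torch.as_tensor(samples[0]["vals"]).shape[-1])

    # Pre-fill with pad values; mask starts as all-PAD, then the valid prefix of each row is cleared.
    vals = torch.full((b, k_max, f_val), float(pad_value), dtype=torch.float32)
    lag_ids = torch.full((b, k_max), int(pad_idx), dtype=torch.int64)
    pad_mask = torch.ones((b, k_max), dtype=torch.bool)

    for i, s in enumerate(samples):
        k = int(lengths[i])
        if k:
            vals[i, :k] = torch.as_tensor(s["vals"], dtype=torch.float32)
            lag_ids[i, :k] = torch.as_tensor(s["lag_ids"], dtype=torch.int64)
            pad_mask[i, :k] = False

    ticker_ids = torch.as_tensor([s["ticker_id"] for s in samples], dtype=torch.int64)
    sector_ids = (torch.as_tensor([s["sector_id"] for s in samples], dtype=torch.int64)
                  if all("sector_id" in s for s in samples) else None)
    meta = [s.get("meta") for s in samples] if any("meta" in s for s in samples) else None

    if pin_memory:
        # Pin tensor fields so host-to-GPU copies can be asynchronous.
        vals, lag_ids, ticker_ids = vals.pin_memory(), lag_ids.pin_memory(), ticker_ids.pin_memory()
        pad_mask, lengths = pad_mask.pin_memory(), lengths.pin_memory()
        sector_ids = sector_ids.pin_memory() if sector_ids is not None else None

    batch = CollatedBatch(vals=vals, lag_ids=lag_ids, ticker_ids=ticker_ids,
                          pad_mask=pad_mask, lengths=lengths,
                          sector_ids=sector_ids, meta=meta)
    return (batch, restore_idx) if sort_by_len else batch
```

With a DataLoader this would be wired up as `collate_fn=collate_variable_lags`, or via `functools.partial` when non-default keyword arguments are needed.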
Limitations
As stated in the parent issue.