Notes
These are very high-level notes (IN PROGRESS)
Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu, Dao 2023)
- scales linearly with context length
- “First, we identify a key limitation of prior models: the ability to efficiently select data in an input-dependent manner.” – Q: was this the initial insight that led to Mamba? -> selectivity = efficiency
- selective mechanism works by making some of the parameters in S4 input-dependent
- parameters B and C are now input-dependent functions, s_B(x) and s_C(x), each with shape (B, L, N) (batch, sequence length, state size)
- ∆ is also made input-dependent via s_∆(x), with shape (B, L, D), so the step size varies per token, giving finer selectivity over the input
- Q: so the L in all of these makes Mamba time-varying whereas S4 is time-invariant?
- the hardware-aware algorithm -> collaborate with a GPU Jedi
- so there is a trade-off…a larger hidden state dimension increases expressivity at the cost of efficiency -> Q: how exactly is expressivity in an SSM different from an attention-based transformer? (think deeply here)
- okay, so the SSM parameters (∆, A, B, C) are loaded from (slow, off-chip) HBM into (fast, on-chip) SRAM
- then the computationally intensive steps (discretization and recurrence) are performed in SRAM (see the sketch after this list)
- then the final outputs are written back to HBM
- okay, then a FlashAttention-style trick for backpropagation (recompute intermediate states rather than store them)…saves on memory usage!
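- to make the shapes concrete, here is a minimal NumPy sketch of the selective recurrence (my own simplified illustration, not the paper's fused CUDA kernel; the weight names and the exact ∆ parameterization are placeholders):

```python
# Minimal sketch of a selective SSM pass. Shapes follow the notes:
# B = batch, L = sequence length, D = channels, N = state size.
import numpy as np

def selective_scan(x, A, W_B, W_C, w_dt):
    """x: (B, L, D) input. A: (D, N) state matrix (still input-independent).
    W_B, W_C: (D, N) projections giving s_B(x), s_C(x). w_dt: (D,) for s_∆(x)."""
    Bsz, L, D = x.shape
    # input-dependent parameters: the "selection" part
    B_t = x @ W_B                                  # (B, L, N)  s_B(x)
    C_t = x @ W_C                                  # (B, L, N)  s_C(x)
    dt = np.logaddexp(0.0, x * w_dt)               # (B, L, D)  softplus, simplified s_∆(x)

    h = np.zeros((Bsz, D, A.shape[1]))             # hidden state, (B, D, N)
    ys = []
    for t in range(L):                             # the recurrence a fused kernel runs in SRAM
        dA = np.exp(dt[:, t, :, None] * A[None])            # discretized A, (B, D, N)
        dB = dt[:, t, :, None] * B_t[:, t, None, :]          # first-order approx for B, (B, D, N)
        h = dA * h + dB * x[:, t, :, None]                   # per-token state update
        ys.append(np.einsum("bdn,bn->bd", h, C_t[:, t]))     # y_t = C_t h_t
    return np.stack(ys, axis=1)                    # (B, L, D)
```

- note the time loop above is exactly what a convolution can no longer replace once B, C, and ∆ depend on x; the paper's kernel fuses these steps and keeps the intermediates in SRAM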
The Selection Mechanism
- prior SSMs like S4 could be computed very efficiently using either a convolution or a recurrence, but to maintain that efficiency they had to be time- and input-invariant.
- authors were motivated to change this in Mamba to allow the hidden state to compress the context and “selectively” attend to the tokens that matter.
- that meant some of the parameters had to become input-dependent, so B, C, and ∆ gained the sequence-length dimension L (they are now functions of the input at each position).
- doing this made Mamba time- and input-varying, eliminating the possibility of computing with a global convolution (recurrence-only mode now).
- that would've eliminated the efficiency of the SSM except for three things:
- 1) kernel fusion - the discretization step, the scan, and the multiplication with C (which is the last step prior to getting an output y) are fused together in one kernel
- 2) parallel scan - the recurrence isn't computed step by step; it's computed with a parallel (associative) scan (see the sketch at the end of these notes)
- 3) recomputation - for backpropagation, recompute intermediate states rather than store them
- and do the majority of these operations in SRAM (rather than HBM) for speed
- NOTE: the authors point out the selection mechanism can be thought of as being very similar to a gating mechanism or hypernetwork. a new term for an old idea.
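- a minimal sketch of the parallel-scan idea (scalar case, my own illustration; the actual kernel applies this per channel and state dimension and fuses it with the other steps): the recurrence h_t = a_t*h_{t-1} + b_t looks inherently sequential, but the pairs (a_t, b_t) compose associatively, so a prefix scan computes every h_t in O(log L) parallel rounds:

```python
import numpy as np

def combine(aL, bL, aR, bR):
    # Associative composition for h_t = a_t*h_{t-1} + b_t:
    # applying (aL, bL) first and then (aR, bR) is equivalent to (aL*aR, aR*bL + bR).
    return aL * aR, aR * bL + bR

def parallel_linear_scan(a, b):
    """Inclusive scan over time (axis 0) via Hillis-Steele recursive doubling:
    O(log L) rounds, each round elementwise (i.e., fully parallelizable)."""
    A, B = a.copy(), b.copy()
    L, shift = len(a), 1
    while shift < L:
        A2, B2 = combine(A[:-shift], B[:-shift], A[shift:], B[shift:])
        A = np.concatenate([A[:shift], A2])
        B = np.concatenate([B[:shift], B2])
        shift *= 2
    return B                                       # h_t, assuming h_0 = 0

# sanity check against the plain sequential recurrence
L = 16
a, b = np.random.rand(L), np.random.rand(L)
h, hs = 0.0, []
for t in range(L):
    h = a[t] * h + b[t]
    hs.append(h)
assert np.allclose(parallel_linear_scan(a, b), hs)
```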