Transformer networks as recurrent or state-space models

December 20, 2017 — June 17, 2024

Tags: language, machine learning, meta learning, neural nets, NLP, stringology, time series

The intersection of transformers and recurrent or state-space models. There is much to say here, but only scrapbook notes for now.

1 Tokens as recurrent state

See also RecurrentGPT (Zhou et al. 2023)

From GitHub - aiwaves-cn/RecurrentGPT:

RecurrentGPT replaces the vectorized elements (i.e., cell state, hidden state, input, and output) in a Long-short Term Memory RNN (LSTM) with natural language (i.e., paragraphs of texts), and simulates the recurrence mechanism with prompt engineering.

At each timestep t, RecurrentGPT receives a paragraph of text and a brief plan of the next paragraph, which are both generated in step t − 1. It then attends to the long-term memory, which contains the summaries of all previously generated paragraphs and can be stored on hard drives, and relevant paragraphs can be retrieved with semantic search.

RecurrentGPT also maintains a short-term memory that summarizes key information within recent timesteps in natural language and is updated at each time step. RecurrentGPT combines all aforementioned inputs in a prompt and asks the backbone LLM to generate a new paragraph, a short plan for the next paragraph, and updates the long-short term memory by rewriting the short-term memory and appending the summary of the output paragraph to the long-term memory.
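
A minimal sketch of that loop, just to make the moving parts concrete. The callables `llm`, `summarise`, and `semantic_search` are hypothetical stand-ins for the prompting, summarisation, and vector-store retrieval steps described above, not the authors' implementation:

```python
# Sketch of a RecurrentGPT-style loop: all recurrent "state" is natural language.
# `llm`, `summarise` and `semantic_search` are hypothetical stand-ins for the
# backbone LLM call, a summarisation prompt, and a vector-store lookup.

def recurrent_generate(llm, summarise, semantic_search,
                       first_paragraph, first_plan, n_steps):
    long_term_memory = []            # summaries of every generated paragraph
    short_term_memory = ""           # rolling natural-language summary
    paragraph, plan = first_paragraph, first_plan
    story = [paragraph]

    for _ in range(n_steps):
        # Retrieve relevant earlier summaries rather than re-reading everything.
        retrieved = semantic_search(query=plan, corpus=long_term_memory, k=3)
        context = "\n".join(retrieved)
        prompt = (
            f"Short-term memory:\n{short_term_memory}\n\n"
            f"Relevant earlier material:\n{context}\n\n"
            f"Previous paragraph:\n{paragraph}\n\n"
            f"Plan for the next paragraph:\n{plan}\n\n"
            "Write the next paragraph, then a brief plan for the one after it."
        )
        paragraph, plan = llm(prompt)                      # new paragraph + next plan
        short_term_memory = summarise(short_term_memory, paragraph)
        long_term_memory.append(summarise("", paragraph))  # grow long-term memory
        story.append(paragraph)
    return story
```

The point of phrasing it this way is that every piece of recurrent "state" is human-readable text, so the recurrence can be inspected or edited mid-generation.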

2 RWKV

State-space (i.e. recurrent) transformers without classic attention. There is a suggestive connection to S4 models.

RWKV is inspired by Apple’s Attention Free Transformer (Zhai et al. 2021). …

How to combine the best of transformers and RNNs? The main drawback of transformer-based models is that it can become challenging to run a model with a context window that is larger than a certain value, as the attention scores are computed simultaneously for the entire sequence.

RNNs natively support very long context lengths—only limited by the context length seen in training, but this can be extended to millions of tokens with careful coding. Currently, there are RWKV models trained on a context length of 8192 (ctx8192) and they are as fast as ctx1024 models and require the same amount of RAM.

The major drawbacks of traditional RNN models and how RWKV is different:

  1. Traditional RNN models are unable to utilize very long contexts (LSTM can only manage ~100 tokens when used as a LM). However, RWKV can utilize thousands of tokens and beyond…
  2. Traditional RNN models cannot be parallelized when training. RWKV is similar to a “linearized GPT” and it trains faster than GPT.

By combining both advantages into a single architecture, the hope is that RWKV can grow to become more than the sum of its parts.
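
Point 2 is essentially the “transformers are RNNs” observation (Katharopoulos et al. 2020, in the references below): if the softmax is replaced by a factorisable kernel, causal attention can be computed either all at once over the sequence (convenient for training) or from a constant-size running state (convenient for inference). A minimal numpy sketch of that equivalence follows; note this is generic linear attention, not RWKV’s exact WKV formula:

```python
import numpy as np

def phi(x):
    """A positive feature map (elu(x) + 1), as used in linear attention."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_parallel(Q, K, V):
    """Training-style form: each output computed from all (causal) keys at once."""
    T = Q.shape[0]
    Qf, Kf = phi(Q), phi(K)
    out = np.zeros_like(V)
    for t in range(T):
        w = Qf[t] @ Kf[: t + 1].T            # unnormalised weights over keys <= t
        out[t] = (w @ V[: t + 1]) / w.sum()
    return out

def linear_attention_recurrent(Q, K, V):
    """Inference-style form: identical outputs from a constant-size running state."""
    d, dv = Q.shape[1], V.shape[1]
    S = np.zeros((d, dv))                    # running sum of phi(k) v^T
    z = np.zeros(d)                          # running sum of phi(k)
    out = np.zeros_like(V)
    for t in range(Q.shape[0]):
        S += np.outer(phi(K[t]), V[t])
        z += phi(K[t])
        out[t] = (phi(Q[t]) @ S) / (phi(Q[t]) @ z)
    return out

# The two forms agree:
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(6, 4)) for _ in range(3))
assert np.allclose(linear_attention_parallel(Q, K, V),
                   linear_attention_recurrent(Q, K, V))
```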

2.1 Mamba

Interesting connection to transformers.

(Gu and Dao 2023):

Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers’ computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.
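
A heavily simplified sketch of what “letting the SSM parameters be functions of the input” looks like, run here as a naive per-step recurrence. The shapes, projections and simplified discretisation are my assumptions for illustration; the paper’s actual contribution is the hardware-aware parallel scan that makes this fast, which is not shown:

```python
import numpy as np

def selective_ssm_scan(x, A_log, W_delta, W_B, W_C):
    """Naive recurrent form of a selective (input-dependent) diagonal SSM.

    x:      (T, D) input sequence
    A_log:  (D, N) parameters; A = -exp(A_log) keeps the state stable
    W_delta (D, D), W_B (D, N), W_C (D, N): projections making the step size
    delta and the B, C matrices functions of the current input.
    These shapes and the simplified discretisation are illustrative guesses."""
    T, D = x.shape
    N = A_log.shape[1]
    A = -np.exp(A_log)                            # (D, N)
    h = np.zeros((D, N))                          # hidden state, per channel
    y = np.zeros((T, D))
    for t in range(T):
        delta = np.logaddexp(0.0, x[t] @ W_delta)     # softplus, (D,)
        B = x[t] @ W_B                                # (N,) input-dependent
        C = x[t] @ W_C                                # (N,) input-dependent
        A_bar = np.exp(delta[:, None] * A)            # (D, N) per-step decay
        h = A_bar * h + (delta[:, None] * B[None, :]) * x[t][:, None]
        y[t] = h @ C                                  # (D,)
    return y

# Tiny smoke test with random parameters.
rng = np.random.default_rng(0)
T, D, N = 10, 4, 8
x = rng.normal(size=(T, D))
y = selective_ssm_scan(
    x,
    A_log=rng.normal(size=(D, N)),
    W_delta=rng.normal(size=(D, D)) * 0.1,
    W_B=rng.normal(size=(D, N)) * 0.1,
    W_C=rng.normal(size=(D, N)) * 0.1,
)
print(y.shape)   # (10, 4)
```

Because delta, B and C depend on the current token, the state can decay quickly for some inputs and persist for others, which is the “selectively propagate or forget information” behaviour the abstract describes.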

See also (Cai et al. 2024; Gu and Dao 2023; Hu et al. 2024; Patro and Agneeswaran 2024).

There is now a sequel, Mamba-2 (Dao and Gu 2024).

3 RecurrentGPT

RecurrentGPT (Zhou et al. 2023) replaces an RNN’s vector-valued state with natural-language text; see the quoted description under Tokens as recurrent state above.


4 Practicalities

For you and me, see AI democratization.

5 References

Cai, Zhu, Wang, et al. 2024. “MambaTS: Improved Selective State Space Models for Long-Term Time Series Forecasting.”
Dao, and Gu. 2024. “Transformers Are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality.”
Das, Kong, Sen, et al. 2024. “A Decoder-Only Foundation Model for Time-Series Forecasting.”
Gu, and Dao. 2023. “Mamba: Linear-Time Sequence Modeling with Selective State Spaces.”
Hu, Baumann, Gui, et al. 2024. “ZigMa: A DiT-Style Zigzag Mamba Diffusion Model.”
Katharopoulos, Vyas, Pappas, et al. 2020. “Transformers Are RNNs: Fast Autoregressive Transformers with Linear Attention.” arXiv:2006.16236 [Cs, Stat].
Nishikawa, and Suzuki. 2024. “State Space Models Are Comparable to Transformers in Estimating Functions with Dynamic Smoothness.”
Patro, and Agneeswaran. 2024. “SiMBA: Simplified Mamba-Based Architecture for Vision and Multivariate Time Series.”
Vardasbi, Pires, Schmidt, et al. 2023. “State Spaces Aren’t Enough: Machine Translation Needs Attention.”
Waleffe, Byeon, Riach, et al. 2024. “An Empirical Study of Mamba-Based Language Models.”
Wang, Gangavarapu, Yan, et al. 2024. “MambaByte: Token-Free Selective State Space Model.”
Zeng, Chen, Zhang, et al. 2023. “Are Transformers Effective for Time Series Forecasting?” In Proceedings of the AAAI Conference on Artificial Intelligence.
Zhai, Talbott, Srivastava, et al. 2021. “An Attention Free Transformer.”
Zhou, Jiang, Cui, et al. 2023. “RecurrentGPT: Interactive Generation of (Arbitrarily) Long Text.”