不是RNN的锅!清华团队深入分析长上下文建模中的状态崩溃,Mamba作者点赞

新车上市作者:wang1232024-12-10更新:

新智元报道

编辑:alan

【新智元导读】RNN模型在长上下文中表现不佳?近日,来自清华的研究团队对此进行了深入的实验分析,结果表明:不是RNN的锅。

与Transformer相比,RNN模型的一大优势是应对长序列的能力。

比如Mamba,内部状态大小始终保持不变,计算随序列长度线性增长,吃得多,消化快。

理论虽如此,但实际情况却是,目前的这些RNN模型在长上下文中的有效性并不能令人满意。

为啥会这样?空有效率但实际上能力不行?

近日,来自清华的研究团队对此进行了深入的实验研究:

论文地址:https://arxiv.org/pdf/2410.07145v1

文章表明,Mamba这类RNN模型在长上下文中主要面临两个问题:

一是无法推断比训练长度更长的输入,原因是较短的训练数据导致了循环状态过拟合;

二是内存容量的上限,由于模型无法有效遗忘很久以前的信息,导致新的信息存不进来了。

——这俩问题明显不是RNN的锅。

而经过研究人员的对症下药,Mamba-2(370M)在256K上下文长度上达到了近乎完美的密钥检索精度。

所以结论就是,Mamba yes!「RNN神教」前景一片光明!

对此,Mamba的作者Albert Gu点赞转发,并发表了相当详细的见解:

「这是一篇很棒的论文(名字也很棒)—— 关于状态空间模型(SSM)的状态容量和长上下文能力的巧妙实验。」

令人惊讶的是,对于每个状态大小 M,当训练上下文长度达到或超过某个临界值 K 时,都会出现一个转折点,在这个点上 SSM 就能够稳健地实现长度泛化。 这是因为当上下文长度小于 K 时,循环状态没有被充分利用,导致模型在训练期间会「过拟合」。但一旦通过足够长序列的训练使模型的状态容量得到充分利用,它就会自动获得泛化能力。 值得注意的是,K 与 M 竟然呈线性关系!—— 这表明每个 token 可能存在某种固有的信息含量(即存在一个值 B,使得上下文中的每个 token 对应 B 字节的循环状态)。这个 B 值可能是由模型架构决定的?

「反过来说,过分担心循环模型的长度泛化问题可能是一个误区。我们无需设计新机制或特殊的缓解措施:只需要在更长的序列上训练(因为是线性时间复杂度,所以不会增加计算开销!),就能获得更好的泛化效果。」

最后,Albert Gu用一句话总结:要让你的Mamba吃得饱饱的,它就能发挥出最佳状态!

喂饱你的Mamba

先来复习一下基础知识。

本文以Mamba2作为主要研究对象,内部的计算表示为下图中的并行结构:

整体的输入输出遵循SSM(也即RNN)的形式:

而把上图中模块内部所有的计算写出来,就是下面这一坨公式:

之前提到的两个问题,核心在于模型的内部状态,也就是ht的表现。

所以下面在探索问题和解决方案时,咱们可以重点关注这些公式中,与ht计算相关的参数。

之前有研究表明,当上下文长度超过其训练长度时,Mamba-1和RWKV-4的性能会严重下降。

顺着这个思路,研究人员在两个方向上进行了实验分析:状态崩溃(STATE COLLAPSE)和容量上限(STATE CAPACITY)。

状态崩溃

状态崩溃(SC)指的是,RNN模型在输入上表现出异常行为的时间比训练期间看到的时间更长的现象。

上图展示了Mamba-2和RWKV-6在训练长度之外的语言建模损失。为了可控性和合成任意长度的提示,这个损失是在仅由「\n」字符组成的提示上计算的(称为「newlines」提示)。

结果表明,当上下文长度远大于其训练长度时,两个RNN的性能都会严重下降,最后就跟瞎猜差不多了。

语言建模可能无法反映下游能力,上图给出了Mamba-2(在8K上下文窗口上训练)在密钥检索任务上的评估结果。

我们可以发现,Mamba-2在8K上下文中具有近乎完美的检索准确性,但在序列长度超过16K后就没法看了,无论模型参数量大小。

从上面的公式来看,这种结果可能出人意料,因为内部状态ht的更新应该具有稳定的指数内存衰减,即对于最后k个token具有良好的检索准确性。

问题出在哪里?

由于递归状态的维度不会随时间而变化,因此状态崩溃期间行为的急剧变化一定是状态值变化的结果。

作者对Mamba-2 370M中每一层的递归状态进行了统计,发现当上下文长度超过训练长度时,一些头部的平均值和方差会急剧变化:

图5显示了模型第38层第2个头的状态,在t=20K时方差爆炸。从中可以发现这种方差爆炸在很大程度上可以归因于少数异常通道,其余大多数通道则相对稳定。

分析一下公式,与ht计算有关的∆t、Bt和xt:

如上图所示,虽然三者都是输入的函数,但xt相对稳定,而Bt比∆t更早发生爆炸,进一步探索还能发现生成∆t和Bt的卷积权重明显更大。

作者认为,产生SC的原因是,对于训练长度来说,状态容量过大,模型能够实现强大的语言建模性能,而无需学习如何忘记。

上图显示了第一个token在不同时间步的内存强度,作者发现爆炸的头(第38层的第2、4、7个头)强烈倾向于在训练长度内保留所有信息,在t=8K时内存强度超过0.8。

解决方案

为了缓解SC,使模型沿序列长度更好地泛化,作者提出了3种解决方案,总的思想是修改状态的update规则来避免其溢出。

Method 1: Forget More and Remember Less

通过增加状态衰减量(忘记更多)或减少输入信息的数量(记住更少)来减少SC,作者选择干预Bt和αt(分别控制输入强度和内存衰减强度)。

Method 2: State Normalization

在每次更新后对状态进行归一化,以确保状态的范数始终低于阈值:

PS:这种方式会将模型转换为非线性RNN,无法以与原始模型相同的方式并行化,预填充速度要慢得多。

Method 3: Sliding Window by State Difference

利用状态ht可以写为加权和的形式,来模拟滑动窗口机制,无需在每一步都从窗口的开头重新处理。

此方法适用于所有可以写成加权和的RNN,包括RWKV 5和6、RetNet、GLA等。尽管会使生成的计算和内存成本翻倍,但仍然是一个可以接受的权衡,因为RNN的生成成本比Transformer低很多。

以上3个是不需要训练的方案,而基于SC是由状态参数过拟合引起的假设,我们也可以尝试使用超过状态容量的序列长度来训练模型。

容量上限

根据以上的讨论,当且仅当训练长度包含的信息少于状态容量时,才会发生SC,所以我们可以通过实验间接估计模型的状态容量。

研究人员训练了多个具有不同状态大小和训练长度的Mamba-2,并将SC未发生的最小训练长度视为状态容量。

实验数据选择RedPajama-V2,一个从CommonCrawl中提取的30T token的开放数据集,进行去重以确保数据质量。

在评估过程中,对长度超过16K token的文档进行抽样,如果不够长,则对其进行拼接。

研究人员试验了具有不同状态大小的模型配置,包括来自Mamba-2官方checkpoint的三个预训练模型,大小分别为130M、370M和780M,另外3个模型(36M、47M、85M)则从头开始训练。

实验结果

上图展示了在Mamba-2 780M上无训练长度泛化方法的结果。我们可以看到,虽然LongMamba大大提高了模型的长度泛化性(3倍以上),但它在较短的序列上会导致明显更大的困惑度,并且仍然不可避免地表现出SC。

相比之下,本文的所有的方法都成功地抑制了SC,使模型能够泛化到超过64K个token。

三种方案中,状态归一化在较短序列上的性能大大低于其他方法,这可能是因为归一化折叠状态会改变heads之间的规范比率,破坏了学习机制。

上图显示了Mamba-2在语言建模和密钥检索方面的状态容量。两个图中最右边的数据点对应于Mamba-2 370M。

左边的图可以拟合出一个线性关系,而右边的图则表明Mamba-2在密钥检索方面的容量与状态大小呈指数级关系。

这是因为上下文中的信息量不会随着其长度的增加而增加。换句话说,模型存储了恒定数量的信息,而状态的组合数量随着元素数量呈指数增长。

参考资料:

https://arxiv.org/abs/2410.07145v1

https://x.com/_albertgu/status/1852011550711632289