【学术速递】SageAttention:低比特量化Attention,实现推理加速!高达2倍,视频、图像、文本生成领域均可用~
原文标题:又快又准,即插即用!清华8比特量化Attention,两倍加速于FlashAttention2,各端到端任务均不掉点!
原文作者:机器之心
冷月清谈:
- 清华大学团队提出了SageAttention,一种8Bit的Attention量化技术,可在不损失精度的情况下将推理速度提高2倍或更高。
- SageAttention通过平滑K矩阵、对Q和K采用分块INT8量化,以及对P和V使用FP16矩阵乘法累加器来解决直接低比特量化Attention导致的精度问题。
- 在视频、图像和文本生成等大模型上,SageAttention已被证明即使在序列长度很长时也能实现即插即用的加速,而不会降低端到端精度。
**具体要点**
- 直接对注意力运算中的矩阵进行低比特量化会导致精度损失,主要是由于矩阵K的异常值分布和矩阵P和V的可变量化精度。
- SageAttention通过平滑K矩阵,并对Q和K采用分块INT8量化以及对P和V使用FP16累加器,有效解决了这些精度问题。
- SageAttention的推理速度比FlashAttention2和xformers分别提高了2.1倍和2.7倍,并且在各种大模型中端到端精度无损失。
怜星夜思:
2、SageAttention在哪些实际应用场景中表现出色?
3、使用SageAttention需要做什么准备工作?
原文内容
AIxiv专栏是机器之心发布学术、技术内容的栏目。过去数年,机器之心AIxiv专栏接收报道了2000多篇内容,覆盖全球各大高校与企业的顶级实验室,有效促进了学术交流与传播。如果您有优秀的工作想要分享,欢迎投稿或者联系报道。投稿邮箱:[email protected];[email protected]
论文第一作者张金涛来自清华大学计算机系,论文通讯作者陈键飞副教授及其他合作作者均来自清华大学计算机系。
大模型中,线性层的低比特量化(例如 INT8, INT4)已经逐步落地;对于注意力模块,目前几乎各个模型都还在用高精度(例如 FP16 或 FP32)的注意力运算进行训练和推理。然而,随着大型模型需要处理的序列长度不断增加,Attention(注意力运算)的时间开销逐渐成为网络优化的主要瓶颈。
为了提高注意力运算的效率,清华大学陈键飞团队提出了 8Bit 的 Attention(SageAttention)。实现了 2 倍以及 2.7 倍相比于 FlashAttention2 和 xformers 的即插即用的推理加速,且在视频、图像、文本生成等大模型上均没有端到端的精度损失。
-
论文标题:SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
-
论文链接:https://arxiv.org/abs/2410.02367
-
开源代码:https://github.com/thu-ml/SageAttention
-
大多视频、图像生成模型中,矩阵 K 表现出了极强的通道维度的异常值分布,直接使用 INT8 或者 FP8 数据类型对其进行量化会导致巨大的误差。
-
在所有模型中,对矩阵 P, V 进行量化不能保证一个模型中所有层的精度。下表展示了对 P, V 量化后,Llama2-7B 和 Unidiffuser 模型所有层中,最差情况的层对应的量化注意力的准确度,(该准确度为量化注意力相比全精度注意力的误差),可以发现不管对 P, V 矩阵进行何种 8Bit (INT8,E4M3,E5M2)量化,总有些层的准确率非常差,导致了端到端效果的下降。
-
对 K 进行平滑处理。SageAttention 采用了一个简单但非常实用的方法来消除矩阵 K 的异常值:K = K – mean (K) 其中 mean (K) 是沿着通道维度求平均值。这个简单的做法不仅不会影响注意力计算的正确性 Softmax (QK^T) = Softmax (Q (K-mean (K))^T) ;且对整个 Attention 速度的影响只有 0.2%;同时还保证了量化后的注意力运算的精度:
-
对 Q, K 进行分块 INT8 量化。对于矩阵 Q, K,SageAttention 采用了以 FlashAttention 的分块大小为粒度的 INT8 量化。这是因为:1. 对 Q, K 矩阵进行 INT8 量化相比于进行 FP8 量化,注意力的精度更高。2. 在一些常用卡上,比如 RTX4090,INT8 矩阵乘法(INT32 为累加器)的速度是 FP8(FP32 为累加器)的两倍。
-
对 P, V 采用 FP16 数据类型的矩阵乘法累加器。对于矩阵 P, V,SageAttention 采用了保留 P, V 为 FP16 的类型,但进行矩阵乘法时采用 FP16 数据类型的累加器。这是因为:1. PV 矩阵乘法的数值范围始终在 FP16 的表示范围内,且经过大量实验验证,FP16 作为累加器的数据类型不会带来任何精度损失(见下表)。2. 在一些常用卡上,比如 RTX4090,以 FP16 为累加器数据类型的矩阵乘法的速度是 FP32 作为累加器的两倍。
© THE END
转载请联系本公众号获得授权
投稿或寻求报道:[email protected]