欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

深入解读FlashAttention2原理,并探索其在AIGC领域的快速实战应用

最编程 2024-02-16 08:59:25
...



FlashAttention-2提出后,便得到了大量关注。本文将具体讲述FlashAttention-2的前世今生,包括FlashAttention1&2的原理解析、加速效果比较以及面向AIGC的加速实践,在这里将相关内容与大家分享~


引言

将 Transformers 扩展到更长的序列长度一直是过去几年的一个热点问题,这将有助于提高语言建模和高分辨率图像理解的能力,也有利于音频和视频生成方面的新应用场景研发。Attention层是扩展到更长序列的主要瓶颈,因为它的运行时间和内存占用是序列长度的二次方。使用近似计算的Attention方法,可以通过减少FLOP计算次数、甚至于牺牲模型质量来降低计算复杂性,但通常无法实现大比例的加速。


由斯坦福大学提出的FlashAttention方法,让使用更长sequence计算Attention成为可能,并且通过线性级别的增长来节省内存以及加速计算。因为FlashAttention没有进行近似计算,所以也没有精度损失。然而,FlashAttention的实际速度仍然和理论上的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。效率低下的原因主要是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。为此,2023年7月,论文作者进一步提出了FlashAttention-2,实现了Attention计算速度的大幅度提升。


FlashAttention

  主要内容


FlashAttention主要关注IO-aware,进一步优化GPU显存的读写效率。这是一种 IO 感知的精确Attention算法,它使用tiling(这里可以理解为分块)来减少 GPU 高带宽内存 (HBM) 和 GPU 片上 SRAM 之间的内存读/写次数。这里的HBM可以理解为显存,SRAM可以理解为cache。通过测试IO复杂性,相比标准 Attention,FlashAttention需要更少的 HBM 访问,并且对于不同的SRAM 大小来说都是有效的。除此以外,FlashAttention还可以扩展到block-sparse attention,产生比任何现有近似注意力方法更快的近似注意力算法。

FlashAttention与 MLPerf 1.1 训练速度相比,对于BERT-large(序列长度 512)实现端到端wall-clock加速15%,对于GPT-2(序列长度 1K)加速 3 倍。FlashAttention 和block-sparse FlashAttention 可在 Transformers 中实现更长的上下文,从而产生更高质量的模型,GPT-2 上的困惑度提升0.7,长文档分类的test结果提高 6.4 个点。

  主要操作



背景知识:

上图的左图,表示存储结构,可以简单理解为:SRAM表示缓存,HBM表示显存,DRAM表示内存。


  • tiling


在不访问整个输入的情况下优化attention计算,并减少相关计算量。重构attention计算,将输入分割成块,并对分块进行多次传递,从而逐步执行attention计算(该步骤称为tiling)。



如上图所示,FlashAttention 使用tiling来防止在相对较慢的 GPU显存上实现大型 ???? × ???? 注意力矩阵(虚线框)计算。在外部循环(红色箭头)中,FlashAttention 循环遍历 K 和 V 矩阵块,并将它们加载到快速片上 SRAM。在每个块中,FlashAttention 循环遍历 Q 矩阵块(蓝色箭头),将它们加载到 SRAM,并将注意力计算的输出写回 HBM。



将输入Q、K、V矩阵分成很多块,将它们从较慢的HBM加载到较快的SRAM,然后在SRAM计算关于这些块的注意力输出。对每个块的计算结果缩放之后进行add操作,则得到正确的结果,具体伪代码如图:


  • recomputing


FlashAttention不专门存储用于后向计算的大型中间计算结果。在SRAM中存储前向计算中的 softmax 归一化因子,以便在后向传递计算梯度的时候快速得到中间结果,这比从 HBM 读取中间计算结果的标准方法更快。

FlashAttention不存储前向计算中????(????2)复杂度的中间值,但是后向传递通常需要矩阵 S, P ∈ R???? ×???? 来计算相对于 Q、K、V 的梯度。通过存储输出 O 和 softmax 归一化统计量 (????, ℓ),则可以在 SRAM 中的 Q、K、V 块的后向计算中快速重新计算注意力矩阵 S 和 P。这可以看作是checkpoint的一种形式。

如下图所示,FlashAttention由于tiling分块操作和recomputing操作,增加了一些计算次数。但是还是通过使用SRAM减少了显存占用,通过减少hbm访问次数加快了attention计算。


  Block-Sparse FlashAttention



论文还提出了Block-sparse FlashAttention,其IO复杂度比FlashAttention小,与稀疏度成正比。非0矩阵越少(即0矩阵越多),稀疏化来压缩数据的空间就越大,block-sparse加速就越明显。上图中,IO复杂度与稀疏性成正比,随着稀疏性的增加(非0矩阵增加),Block-sparse FlashAttention的运行时间成比例地提高。

  小结


总的来说,FlashAttention有如下优点:

  1. hbm访问次数降低,所以计算更快

  2. 在sram中计算attention,并对于后向计算提前保留中间结果,所以显存占用更少

  3. 可以使用更长的sequence,使得模型训练效果更好

  4. 对于attention计算,加速明显。如果加上稀疏化处理,速度会更快。


FlashAttention-2


  主要内容


FlashAttention 的整体速度仍然和单独进行矩阵乘法 (GEMM) 的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。作者观察到效率低下的原因是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。最新提出 FlashAttention-2,通过更好的工作分区来解决这些问题,主要包含的操作:1.调整算法以减少非矩阵乘运算的计算次数。2.跨不同线程块进行并行化注意力计算。3.在每个线程块内, 在 warp 之间优化工作分配以减少共享内存的通信。

与FlashAttention 相比,FlashAttention-2速度提高了约 2 倍,达到 A100 上理论最大 FLOPs/s 的 50-73%,接近 GEMM 操作的效率。根据经验验证,当使用端到端来训练 GPT 式模型时,FlashAttention-2 的训练速度高达每 A100 GPU 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。不同设置(有或没有causal mask、不同头部尺寸)的测试表明,FlashAttention-2 比 FlashAttention 实现了约 2 倍的加速,在前向传递中达到理论最大吞吐量的 73%,在后向传递中达到理论最大吞吐量的 63%。

  主要操作


  • 减少非矩阵运算


调整算法以减少非 matmul(矩阵乘法) FLOP 的数量,同时不改变输出。虽然非 matmul FLOP 只占总 FLOP 的一小部分,但它们由于 GPU 具有专门的矩阵乘法单元,非矩阵乘法的运算需要更长的时间来执行,矩阵乘法吞吐量可以比非矩阵乘法吞吐量高出16倍。因此,减少非 matmul FLOP 并尽可能多的进行 matmul FLOP 非常重要。

背景知识:
吞吐量是指单位时间内完成的任务数量或数据处理量。在这个上下文中,吞吐量指的是执行矩阵乘法操作时的性能表现,以及执行其他非矩阵乘法操作时的性能表现。这句话的意思是,执行矩阵乘法操作时,系统能够以每单位时间处理更多的任务或数据,其数量可以高达非矩阵乘法操作时的16倍。这表明矩阵乘法操作在性能上比其他操作更加高效。


Forward pass:优化qkv的softmax计算中非矩阵运算

Backward pass:FlashAttention-2的后向传递与FlashAttention几乎相同,主要区别在于需要进行梯度计算与更新。这里做了一个小调整,只使用求和结果????,而不是 softmax 中的行式最大值和行式指数和。


  • 增加并行比例


除了batchsize维度和head数目维度,还在序列长度维度上对前向传播和反向传播进行并行化处理,提高并行性。在序列较长的情况下,提高GPU资源的占用率。FlashAttention对于batchsize和head数目进行并行化处理,FlashAttention2基于序列长度进行并行化。当批量大小和head数量较小时,序列长度上增加的并行性有助于提高占用率(正在使用的 GPU 资源的比例),从而在这种情况下实现加速。


Forward pass:对批量维度和头数维度进行并行化,如 FlashAttention 中所做的那样。对于外循环(在序列长度上),将它们调度到不需要彼此通信的不同线程块上,每个工作线程负责关注矩阵的一行block块。外循环每次处理一行block,内循环每次处理这一行中的一列block,这和FlashAttention处理方式是不同的。



Backward pass:不同列块之间唯一共享的计算是算法 2 中更新的dQ,其中我们需要将 dQ从 HBM 加载到 SRAM,然后在片上通过 dQ更新,并写回 HBM。我们使用原子添加在不同线程块之间进行通信以更新 dQ。我们也在序列长度维度上进行并行化,并为后向传递的每一列block块安排 1 个工作线程(和前向传递是反过来的)。


  • 在warp上优化工作划分


在一个注意力计算的block内,在一个thread block的不同warp之间优化工作划分,以减少通信和共享内存的读/写。



在每个线程块内,我们也必须决定如何在不同的 warp 之间划分工作。我们通常每个线程块使用 4 或 8 个 warp,分区如上图所示。



Forward pass:对于每个块,FlashAttention 将 K 和 V 分割到 4 个 warp 上,同时保持 Q 可被所有 warp 访问。每个warp相乘得到 QK⊤ 的slice,然后它们需要与 V 的slice相乘并进行通信以将结果相加。这称为“split-K”方案。然而,这是低效的,因为所有 warp 都需要将其中间结果写入共享内存,进行同步,然后将中间结果相加这些共享内存读/写会减慢 FlashAttention 中的前向传播速度。在 FlashAttention-2 中,我们将 Q 分成 4 个经线,同时保持所有经线均可访问 K 和 V。在每个扭曲执行矩阵乘法以获得 QK⊤ 切片后,它们只需与共享的 V 切片相乘即可获得相应的输出切片。warp 之间不需要通信。共享内存读/写的减少可以提高速度


背景知识:
warp:由多个thread组成,是编程层面的概念。
flash1:k和v被分为4个不同的warp,q和k计算、再和v计算,每一次计算的中间结果都要写入共享内存,并在之后被读取。这样就增加了共享内存的读写次数、拖慢了速度。
flash2:将q分为4个不同的warp,然后计算qk、计算v。但是这里k和v不需要通信,所以计算v的时候,不需要新的内存读写。这样就减少了读写次数、加快了程序。

Backward pass:对于后向传递,我们选择对warp进行分区以避免“split-K”方案,从而减少共享内存的读/写次数,并再次提高速度。由于所有不同输入和梯度 Q、K、V、O、dO、dQ、dK、dV 之间的依赖性更加复杂,它需要一些同步操作。

  小结


FlashAttention-2可以加速attention计算。测量FlashAttention-2 在不同序列长度上的运行时间,并与 PyTorch、FlashAttention 和 Triton 中的 FlashAttention 中的标准实现进行比较。FlashAttention-2 比 FlashAttention 快 1.7-3.0 倍,比 Triton 中的 FlashAttention 快 1.3-2.5 倍,比标准注意力实现快 3-10 倍。

FlashAttention-2可以加速端到端训练。当使用端到端在 2k 或 8k 序列长度上训练大小为 1.3B 和 2.7B 的 GPT 型模型时,FlashAttention-2 与 FlashAttention 相比可实现高达 1.3 倍的加速,与基线相比可实现 2.8 倍的加速 没有FlashAttention。每个 A100 GPU 的 FlashAttention-2 速度高达 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。

FlashAttention-2 比 FlashAttention 快 2 倍,可以用之前训练 8k 上下文模型的时间,来训练具有 16k 更长上下文的模型。使用更长的context训练模型,可以更好理解长篇书籍和报告、高分辨率图像、音频和视频。

FlashAttention-2加速实践


  时间与显存的优化效果


对于qkv计算,比较FlashAttention2与custom pytorch、xformers(FlashAttention1)的时间与显存消耗。如果只考虑QKV计算,flash attention2耗时是xformers(flash attention1)的一半,内存节省也更多一些。


flash attention2耗时是xformersflash attention1)的一半,内存节省也更多一些