当前,Flash Attention 已成为标配,但它并不会显式计算并存储注意力权重 (attention weights);因此,如果大家想要获得用于衡量 lazy ratio 的注意力信息,就必须重新计算注意力权重,这会带来不可忽视的额外开销。
解决方案:为避免重复计算,大家借鉴了 online softmax 的思路,利用 Flash Attention 在计算过程中生成的 LSE(log-sum-exp)作为 lazy ratio 的分母。更值得注意的是,大家惊喜地发现分子的计算复杂度仅为 O (1),而若重新计算则需要 O (seq_len),因此这种方法有效地避免了大规模的重复开销。具体算法如下:
问题 2:prefilling 阶段的峰值内存
若等到 prefilling 结束后才根据各层的 lazy ratio 进行识别和转换,那么整个 prefilling 阶段所需的内存峰值并没有减少。