吴忠躺衫网络科技有限公司

0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發帖/加入社區
會員中心
創作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

深入淺出理解PagedAttention CUDA實現

深度學習自然語言處理 ? 來源:PaperWeekly ? 2024-01-09 11:43 ? 次閱讀

vLLM 中,LLM 推理的 prefill 階段 attention 計算使用第三方庫 xformers 的優化實現,decoding 階段 attention 計算則使用項目編譯 CUDA 代碼實現。具體代碼在 vllm 的 csrc/attention/attention_kernels.cu 文件里,開發者洋洋灑灑寫了八百多行 CUDA 代碼。

Attention 計算時使用頁式(paged)管理 KVCache 用于增加服務吞吐率,但對延遲有負面影響,因此高效的 PA 實現方法,利用頁式內存管理同時盡量降低其負面影響,對框架的綜合性能表現至關重要。

本文章將描述 PA CUDA Kernel 的實現細節,這些細節是公開的論文和博客所不涉及的,但卻對框架的速度至關重要。另外,PA 實現改編自 FasterTransformers 某個版本的 MHA 實現,NV 原始版本對 GPU 特性的運用也是相當老道的,值得大家借鑒。

vLLM 中有兩個版本 PA,使用一個簡單的啟發式方法來決定是使用 V1 還是 V2 版本。V1 是本文介紹的版本,改編自 FasterTransformers 的 MHA 實現。V2 是參考 FlashDecoding 方式進行實現,對 sequence 維度進行切分以增加并行粒度,關于 FlashDecoding 可以參考本人知乎文章。V1 適合長度小于 8192 或者 num_seqs * num_heads>512 的情況。

參數定義和數據結構

num_seq:本次推理請求 sequence 數目。

num_head:Query 的 head 數目。

num_kv_heads:Key、Value 的 head 數目,對于 MHA 和 num_head 相同,如果是 GQA、MQA 則 num_kv_heads 小于 num_head。

head_size hidden dimension,特征的維度。

PA 使用 tensor 的維度信息

out [num_seqs, num_heads, head_size]

Q [num_seqs, num_heads, head_size]

KCache [num_blocks, num_kv_heads, head_size/x, block_size, x]:x 表示一個向量化的大小,如 float16 -> 16 / sizeof(float16) = 8。

VCache [num_blocks, num_kv_heads, head_size, block_size]

Paged 內存管理相關的輔助數據結構:

blk_size:也就是 block_size,是 KVCache page 的最高維,KVCache 是若干個 page 的集合,每個 page 存(blk_size, num_head,head_size)個 K、V 的元素。

head_mapping [num_heads] 用于 MQA, GQA,確定用的 KV_head

block_tables [num_seqs, max_num_blocks_per_seq] block_tables 映射表,表示每個 sequence 映射到哪幾個 block 上

context_lens [num_seqs] 用于變長

課前問題

如果你能回答以下兩個問題,那么說明你已經非常熟練地掌握了 PA 實現,并可以用批判性的眼光審閱本文,找出其中可能存在的錯誤。如果你暫時無法回答這些問題,請不要擔憂,閱讀完本文后會給你答案。

Q1:為什么 K Cache 的 layout 和 V Cache layout 不一樣?

Q2:PA 實現和 FlashAttention 有什么區別?

PagedAttention算子計算流程

首先,按照 CUDA 編程模型對任務進行并行劃分,grid 大小(num_heads, num_seqs),grid 中每個 CUDA thread block 大小(NUM_THREADS),NUM_THREADS 是常量默認為 128,也就說每個 thread block 包含 128 個線程,負責完成 output 矩陣一行(包含 head_size 個元素)結果的 attention 計算任務。thread block 中的線程進一步劃分若干個WARP。

眾所周知,WARP 是 GPU 一個基本的執行單元,由 32 個線程組成,這些線程以 SMIT 方式在硬件上同時執行相同的指令,在不同的數據上進行操作。在 PA 中比較特殊的是,warp 內 32 個線程進一步劃分為 blk_size 個 thread group,這和 paged KVCache 設計 x 息息相關的,馬上會細講。

Attention 計算 softmax(QK^T)V,一圖勝前言,后面流程介紹將圍繞下面這幅圖展開。其中 thread block, warp, thread group, thread 別用不同顏色表示。

ed093146-ae34-11ee-8b88-92fbcf53809c.png

▲ 圖1:PagedAttention CUDA計算流程

在上圖的左側部分,我們看到了 Q 矩陣,這部分描述了從顯存讀取 Q 數據到共享內存的過程。在這個過程中,一個 CUDA 線程塊會讀取圖中 Q 矩陣的一行(包含 head_size個元素)并將其存入共享內存。

這個過程是通過一個循環來實現的,在每次迭代中,每個 thread group 會讀取 16 字節的 Q 數據(例如,如果使用 float16,那么就是 8 個元素)。每個 warp 會讀取 16*blk_size 字節的 Q 數據,這些數據對應于一個 sequence 的一個 head,由 CUDA grid 索引指定。當循環訪問結束后,共享內存存儲 Q 行的一部分。如下圖所示,綠色部分表示存儲在一個線程讀入共享內存中的數據。

ed1a631c-ae34-11ee-8b88-92fbcf53809c.png

圖 1 中上面部分 K 矩陣部分描述了從顯存讀取 K Cache 到寄存器的過程。每個序列的 K Cache 包含 cxt_length * num_kv_heads * head_size 個元素,但由于采用了頁式內存管理,這些元素在內存中的存儲并不連續。每個 thread block 只負責計算一個 sequence 一個 head 的 QK^T,因此只需要 ctx_length * head_size 個 K Cache 元素。

然而,由于 ctx_length 維度的存儲是不連續的,并且以 blk_size 個 token 為粒度分布在不同的內存地址,我們需要根據query的head_idx和 seq_idx 訪問 block_table 以找到 K Cache的physical_block_num。為了方便后續的描述,我們可以將 K Cache 視為(:, head_size)的形狀,其中 head_size 個元素組成一行。

K Cache 的布局為 [num_blocks, num_kv_heads, head_size/x, block_size, x],這是為了優化寫入 shared memory 的操作。在 Q 和 K 矩陣的同一行元素被讀入寄存器并進行點乘運算后,結果需要被存入 shared memory。

如果一個 warp 中所有線程都計算 Q、K 同一行數據,會導致寫入 shared memory 的同一個位置,這將造成 warp 內不同線程順序地寫入。因此,為了優化,warp的線程最好計算 Q 和 K 的不同行數據。因此,在設計 K Cache 布局時,我們將 block_size 放在比 head_size 更低的維度。

由于 warp size 大于 block_size,我們需要將 head_size 拆分為 head_size/x 和 x 兩個維度,借 x 到最低維度,以確保每個線程讀入的數據量和計算量都足夠大。最后,每個線程組派一個線程去寫入 shared memory,這樣一個 warp 有 blk_size 個線程并行寫入 shared memory,從而增加了 shared memory 的訪問帶寬。這種設計策略是為了實現高效的并行計算和內存訪問,以提高整體的計算性能。

在代碼實現中,訪問 K 矩陣需要一個循環,該循環使得 CUDA 線程塊中的所有 warp 依次訪問 num_block 個頁面。在每次循環迭代中,每個 warp 負責訪問連續的 blk_size個K Cache 行,這涉及到的數據量為 blk_size * head_size 個元素。同時,每個 thread group 負責訪問 K Cache 的一行,將 head_size 個元素加載到自己的寄存器中。

接著,寄存器中的 Q 和 K 數據元素立即進行點乘運算,運算結果被寫入 shared memory 中。因此,線程塊的 shared memory 存儲了一行 QK^T 的結果,包含 ctx_length 個元素。這種實現方式充分利用了 CUDA 的并行計算能力,以提高數據處理的效率。

然后,thread block 對 shared memory 中元素進行 max,sum 方式 reduction,然后計算得到 softmax 結果。

圖 1 右邊 V 矩陣部分描述從顯存讀 V Cache 到寄存器過程。和 K Cache 一樣,CUDA thread block 依次訪問 num_blk 個物理塊到寄存器,每個 warp 負責 blk_size 個 token 的 page 內存,page 的真實物理地址同樣需要進行索引。

不過這里不需要以 thread group 為單位訪問 16 字節,而是每個 thread 訪問 16 字節的元素。訪問完就可以與 shared memory 的 softmax(QK^T) 中間結果對應位置 16 字節的數據進行點乘,得到一個 float 結果,寫到 output 對應位置中。

為什么V Cache的layout是 [num_blocks, num_kv_heads, head_size, block_size],和 K Cache layout 不一樣?這是因為 V 要去做點乘的對象在shared memory,只需要讀,不涉及并行寫的問題。

和 FlashAttention(FA)有什么不同?結合我的圖和中間 FAv2 的流程圖對比就一目了然了。FA 用了兩層循環,每次寫一個 Tile 的 output tensor,而 PA 一直只有一層循環,每次寫一行 output tensor。因為每次都有整行的 QK^T 中間結果,不需要 online softmax 這種花哨技巧。

ed257e1e-ae34-11ee-8b88-92fbcf53809c.png

PAv1的問題

以我粗淺的理解指出幾點 vLLM PAv1 的問題。一、和 MHA 相比,MQA 和 GAQ 沒有減少對 KV Cache 的讀寫次數。讀 K、V Cache 時候只是做了一個 head_idx 的轉換,會重復從顯存讀相同的 head。二、對于 seq length 很長情況沒法適應,因為沒有沿著 ctx_length 或者 batch 維度做切分。這點 FlashAttention 和 FlashDecoding 就做了,因此 PAv2 借鑒了 FA 的切分思想。

總結

vLLM 的 paged attention v1 實現繼承自 FasterTransformers MHA 實現,它和 FlashAttention 的并行任務劃分方式不同。其中對 KVCache layout 的設計比較巧妙,充分利用了 shared memory 寫帶寬,是一種常用 CUDA 編程技巧。







審核編輯:劉清

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規問題,請聯系本站處理。 舉報投訴
  • 寄存器
    +關注

    關注

    31

    文章

    5363

    瀏覽量

    121171
  • Cache
    +關注

    關注

    0

    文章

    129

    瀏覽量

    28433
  • 內存管理
    +關注

    關注

    0

    文章

    168

    瀏覽量

    14190
  • MQA
    MQA
    +關注

    關注

    0

    文章

    3

    瀏覽量

    6058

原文標題:vLLM皇冠上的明珠:深入淺出理解PagedAttention CUDA實現

文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    深入淺出AVR(傻孩子)

    本帖最后由 eehome 于 2013-1-5 09:56 編輯 深入淺出AVR(傻孩子)
    發表于 06-29 15:43

    深入淺出AVR

    深入淺出AVR,一本書。
    發表于 07-15 12:02

    深入淺出玩轉FPGA

    深入淺出玩轉FPGA
    發表于 07-21 09:21

    深入淺出ARM7

    深入淺出ARM7
    發表于 08-18 10:12

    HDMI技術深入淺出

    HDMI技術深入淺出
    發表于 08-19 10:52

    深入淺出Android

    深入淺出Android
    發表于 08-20 10:14

    深入淺出Android

    深入淺出Android
    發表于 04-26 10:48

    深入淺出安防視頻監控系統

    深入淺出安防視頻監控系統深入淺出安防視頻監控系統
    發表于 05-22 19:28

    深入淺出AVR

    深入淺出AVR
    發表于 08-23 10:10

    深入淺出數據分析

    深入淺出數據分析,有需要的朋友下來看看。
    發表于 01-15 14:22 ?0次下載

    深入淺出談多層面板布線技巧

    深入淺出談多層面板布線技巧
    發表于 12-13 22:20 ?0次下載

    深入淺出Android—Android開發經典教材

    深入淺出Android—Android開發經典教材
    發表于 10-24 08:52 ?15次下載
    <b class='flag-5'>深入淺出</b>Android—Android開發經典教材

    深入淺出數字信號處理

    深入淺出數字信號處理
    發表于 12-07 20:14 ?555次閱讀

    深入淺出理解阻抗匹配

    深入淺出理解阻抗匹配
    的頭像 發表于 02-03 15:14 ?4226次閱讀

    深入淺出學習250個通信原理資源下載

    深入淺出學習250個通信原理資源下載
    發表于 04-12 09:16 ?28次下載
    百家乐官方网站| 百家乐官网发牌盒子| 百家乐真人视屏游戏| 娱乐城注册体验金| 皇室百家乐官网的玩法技巧和规则| 免费百家乐游戏下| 安庆市| 线上百家乐赢钱| 汤阴县| 澳门百家乐网站bt| 阜平县| 百家乐游戏源码手机| 百家乐官网注册下注平台| 百家乐平注胜进与负追| 唐山市| 百家乐的玩法技巧和规则| 百家乐官网最新投注方法| 缅甸百家乐视频| 百家乐官网游戏源码手机| 大发888出纳柜台| 手机百家乐官网的玩法技巧和规则| 大发888客户端de 软件| 大三巴百家乐官网的玩法技巧和规则 | 真人百家乐官网游戏网址| 大发888体育场| 网络百家乐官网会输钱的多吗 | 百家乐官网赌的技巧| 百家乐游戏接口| 最好的百家乐官网博彩网站 | 鸿发| 信誉百家乐博彩网| 淘金百家乐官网现金网| 沙龙百家乐代理| 机械百家乐官网技巧| 百乐门娱乐城注册| 百家乐讯特| 百家乐官网威尼斯人| 谈大发888风水和运气| 百家乐论坛百科| 百家乐官网视频双扣游戏| 威尼斯人娱乐城|