0x0. 前言
本文解析一下mlc-llm(https://github.com/mlc-ai/mlc-llm)對大模型推理的流程以及使用的圖優化,算子優化策略。mlc-llm的模型部署流程可以查看官方文檔:https://mlc.ai/mlc-llm/docs/ ,也可以參考我前段時間寫的這篇MLC-LLM 部署RWKV World系列模型實戰(3B模型Mac M2解碼可達26tokens/s) 。
此外,閱讀mlc-llm的代碼還需要理解一些TVM Unify的一些基礎概念,可以參考TVM 學習指南(個人版) ,Relax: TVM 的下一代圖層級 IR,新一代深度學習編譯技術變革和展望等等。從 https://github.com/BBuf/tvm_mlir_learn 這里可以查看更多相關博客和資料。
在 MLC-LLM 部署RWKV World系列模型實戰(3B模型Mac M2解碼可達26tokens/s) 中提到要使用mlc-llm部署模型首先需要一個編譯過程,將原始的基于Realx搭建的模型比如RWKV和給定的device信息一起編譯為TVM中的runtime.Module(在linux上編譯的產物就是.so文件)提供mlc-llm的c++推理接口調用 。我們就從這里看起:
由于mlc-llm上游更新很快,為了準確標定代碼位置我fork了一份2023年9月17號的mlc-llm代碼 :https://github.com/BBuf/mlc-llm-code-analysis ,本文的注釋以及指出的代碼位置均以這個fork倉庫為準。
0x1. 編譯流程解析
編譯的入口在:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/build.py 。
這個腳本構建了一個模型build的入口,可以通過傳入不同的參數來構建不同配置的模型。參數解析和模型編譯都在 https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py 中實現,模型編譯準備(mod_transform_before_build函數)和編譯(build函數)兩個階段。在模型編譯準備階段,包含準備需要優化的算子,執行一些基礎的圖變換,針對cuda做進一步優化,做算子fuse等優化,詳細的解釋清閱讀這里的注釋:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L378 。
在這之后會執行編譯過程:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L378 。從這里我們可以看到,對于GPU來說使用的是默認的schedule模板,并沒有使用AutoTVM/Ansor等等調優工具,這一點是很友好的,個人猜測也是因為Transformer架構的模型是很固定的,然后優化方法也比較統一。
上面的編譯前準備和編譯都是針對IRModule來說的,那么這個IRModule是怎么來的呢?以及量化是在哪里做的?這兩個問題都是在 build_model_from_args 函數: https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L627 處理的,發生在 mod_transform_before_build 函數調用之前。以 RWKV 模型為例,通過這行 mod, param_manager, params, model_config = rwkv.get_model(args, config) 代碼完成了從原始的 huggingface 模型到初始的 IRModule 的轉換,在這個過程中也包含了量化。
0x2. 模型搭建解析
0x2.1 模型組件搭建
首先在 https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/relax_model/modules.py 這里基于Relax的內部接口(relax.Expr,relax.testing.nn.Module,relax.op.xxx等等)定義了搭建LLM模型需要的一些組件比如 ModuleList,Linear,Embedding,LayerNorm,RotaryEmbedding等等。這個地方我添加了一些解釋,請點上面的源碼鏈接查看。然后這個地方需要注意2個特殊的op,第一個是來自 https://github.com/mlc-ai/relax/blob/ceaf7b0156524d30537a3de5fa30764eaff4edb8/python/tvm/relax/op/index.py#L28 的:
def?take(x:?Expr,?indices:?Expr,?axis:?Optional[int]?=?None)?->?Expr: ????return?_ffi_api.take(x,?indices,?axis)??#?type:?ignore
這個函數,實現了take的核心功能,與numpy和pytorch的take語義類似,都可以通過指定indices來從輸入張量中抽取值。主要調用了_ffi_api.take進行取值操作, 這個_ffi_api是relax底層實現, take操作的實際計算會在這里進行。這個函數被用于Embedding組件的搭建中。
另外nn.emit這個接口的作用是將一個relax.Expr表達式轉化為relax.Var變量,并保存該變量。
最后我們注意到這里搭建的Relax模塊風格和PyTorch的模塊風格基本一致,也可以看出Relax前端是不斷靠近動態圖風格,追求更佳的易用性。
0x2.2 模型搭建
首先看一些準備工作:
#?@dataclass:這個裝飾器用于指示RWKVConfig類是一個數據類。用于存儲RWKVModel的配置信息。 @dataclass class?RWKVConfig: ????"""The?configuration?class?to?store?the?configuration?of?a?`RWKVModel`.""" ????num_hidden_layers:?int?#?類中的一個屬性,用于存儲隱藏層的數量,類型為整數。 ????vocab_size:?int?#?類中的一個屬性,用于存儲詞匯表的大小,類型為整數。 ????hidden_size:?int?#?類中的一個屬性,用于存儲隱藏層的大小,類型為整數。 ????intermediate_size:?int?#?類中的一個屬性,用于存儲中間層的大小,類型為整數。 ????rescale_every:?int?=?0?#?類中的一個屬性,默認值為0,用于存儲重新縮放的頻率,類型為整數。 ????layer_norm_epsilon:?float?=?1e-5?#?類中的一個屬性,默認值為1e-5,用于存儲層歸一化的epsilon值,類型為浮點數。 ????max_sequence_length:?int?=?1024?#?類中的一個屬性,默認值為1024,用于存儲最大序列長度,類型為整數。 ????dtype:?str?=?"float32"?#?類中的一個屬性,默認值為"float32",用于存儲數據類型,類型為字符串。 ????def?__init__( ????????self, ????????num_hidden_layers:?int, ????????vocab_size:?int, ????????hidden_size:?int, ????????intermediate_size:?int, ????????rescale_every:?int?=?0, ????????layer_norm_epsilon:?float?=?1e-5, ????????context_length:?int?=?1024, ????????dtype:?str?=?"float32", ????????**kwargs, ????)?->?None: ????????self.num_hidden_layers?=?num_hidden_layers ????????self.vocab_size?=?vocab_size ????????self.hidden_size?=?hidden_size ????????self.intermediate_size?=?intermediate_size ????????self.rescale_every?=?rescale_every ????????self.layer_norm_epsilon?=?layer_norm_epsilon ????????self.max_sequence_length?=?context_length ????????self.dtype?=?dtype ????????self.kwargs?=?kwargs #?用來索引RWKV的Attention和FFN部分存儲的狀態或者叫Cache。 #?python代碼可以參考:?https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L858-L867 class?State: ????ATT_X?=?0 ????ATT_A?=?1 ????ATT_B?=?2 ????ATT_P?=?3 ????FFN_X?=?4
這里的State是用來索引RWKV的Attention和FFN部分存儲的狀態或者叫Cache,每一個Layer有5個不同的State,并且每個State的shape都是[1, hidden_size],這里的1代表的應該是batch緯度。
#?義了一個名為_load_state的函數,它接受一個名為state的參數,類型為Expr,一個名為hidden_size的參數,類型為整數, #?一個名為dtype的參數,類型為字符串。函數的返回類型為Expr。 def?_load_state(state:?Expr,?hidden_size:?int,?dtype:?str)?->?Expr: ????#?Reuse?`attention_kv_cache_view` ????#?將外部函數vm.builtin.attention_kv_cache_view賦值給變量f_load_cache。relax.extern是一個外部函數調用的語法, ????#?它指示編譯器在編譯時將該函數調用轉換為相應的外部函數調用。 ????f_load_cache?=?relax.extern("vm.builtin.attention_kv_cache_view") ????#?使用nn.emit方法生成一個表達式對象,該表達式表示對外部函數f_load_cache的調用。 ????#?調用的參數是一個列表,包含state和R.shape([1,?hidden_size]),以及sinfo_args參數指定的一個R.Tensor對象。 ????cache?=?nn.emit( ????????relax.Call( ????????????f_load_cache, ????????????[state,?R.shape([1,?hidden_size])], ????????????sinfo_args=[R.Tensor((1,?hidden_size),?dtype)], ????????) ????) ????return?cache #?定義了一個名為_store_state的函數,它接受一個名為state的參數,類型為Expr,一個名為value的參數,類型為Expr。 def?_store_state(state:?Expr,?value:?Expr): ????#?Reuse?`attention_kv_cache_update` ????#?將外部函數vm.builtin.attention_kv_cache_update賦值給變量f_store_cache。 ????#?relax.extern是一個外部函數調用的語法,它指示編譯器在編譯時將該函數調用轉換為相應的外部函數調用。 ????f_store_cache?=?relax.extern("vm.builtin.attention_kv_cache_update") ????#?使用nn.emit方法生成一個表達式對象,該表達式表示對外部函數f_store_cache的調用。 ????#?調用的參數是一個列表,包含state和value,以及sinfo_args參數指定的一個R.Object()對象。 ????return?nn.emit( ????????relax.Call( ????????????f_store_cache, ????????????[state,?value], ????????????sinfo_args=[R.Object()], ????????) ????)
這兩個函數用來加載和存儲RWKV模型的State。接下來看一下對應 https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L741 這里的torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) 的Relax實現,為了方便對照先貼一下原始的wkv forward cuda kernel:
?
?
template?__global__?void?kernel_wkv_forward(const?int?B,?const?int?T,?const?int?C, ???????????????????????????????const?float?*__restrict__?const?_w,?const?float?*__restrict__?const?_u,?const?F?*__restrict__?const?_k,?const?F?*__restrict__?const?_v, ???????????????????????????????F?*__restrict__?const?_y,?float?*__restrict__?const?_aa,?float?*__restrict__?const?_bb,?float?*__restrict__?const?_pp)?{ ????const?int?idx?=?blockIdx.x?*?blockDim.x?+?threadIdx.x; ????const?int?_b?=?idx?/?C; ????const?int?_c?=?idx?%?C; ????const?int?_offset?=?_b?*?T?*?C?+?_c; ????const?int?_state_offset?=?_b?*?C?+?_c; ????float?u?=?_u[_c]; ????float?w?=?_w[_c]; ????const?F?*__restrict__?const?k?=?_k?+?_offset; ????const?F?*__restrict__?const?v?=?_v?+?_offset; ????F?*__restrict__?const?y?=?_y?+?_offset; ????float?aa?=?_aa[_state_offset]; ????float?bb?=?_bb[_state_offset]; ????float?pp?=?_pp[_state_offset]; ????for?(int?i?=?0;?i? void?cuda_wkv_forward(int?B,?int?T,?int?C,?float?*w,?float?*u,?F?*k,?F?*v,?F?*y,?float?*aa,?float?*bb,?float?*pp)?{ ????dim3?threadsPerBlock(?min(C,?32)?); ????assert(B?*?C?%?threadsPerBlock.x?==?0); ????dim3?numBlocks(B?*?C?/?threadsPerBlock.x); ????kernel_wkv_forward<< >>(B,?T,?C,?w,?u,?k,?v,?y,?aa,?bb,?pp); }
這個cuda kernel里面,B表示batch_size,在mlc-llm的實現默認為1。然后T表示序列長度,C表示隱藏層緯度。然后我們就可以對應來看mlc-llm的wkv實現了。
#?定義了一個名為create_wkv_func的函數,它接受一個名為hidden_size的參數, #?類型為整數,一個名為dtype的參數,類型為字符串,一個名為out_dtype的參數,類型為字符串。 def?create_wkv_func(hidden_size:?int,?dtype:?str,?out_dtype:?str): ????@T.prim_func ????def?wkv_func( ????????k:?T.handle, ????????v:?T.handle, ????????time_decay:?T.handle, ????????time_first:?T.handle, ????????saved_a:?T.handle, ????????saved_b:?T.handle, ????????saved_p:?T.handle, ????????wkv:?T.handle, ????????out_a:?T.handle, ????????out_b:?T.handle, ????????out_p:?T.handle, ????): ????????#?設置TIR函數的屬性。這里設置了三個屬性,包括op_pattern、tir.noalias和tir.is_scheduled。 ????????T.func_attr({"op_pattern":?8,?"tir.noalias":?True,?"tir.is_scheduled":?1}) ????????#?聲明一個名為context_length的變量,類型為T.int64(),用于存儲上下文長度。 ????????context_length?=?T.int64() ????????#?創建一個名為K的匹配緩沖區,通過T.match_buffer方法匹配參數k的形狀和數據類型。 ????????#?K的形狀在原始的ChatRWKV中為B,T,C,只不過這里B=1 ????????#?這里的k就是上面cuda?kernel的_k ????????K?=?T.match_buffer(k,?(context_length,?hidden_size),?dtype=dtype) ????????#?創建一個名為V的匹配緩沖區,通過T.match_buffer方法匹配參數v的形狀和數據類型。 ????????#?這里的v就是上面cuda?kernel的_v ????????V?=?T.match_buffer(v,?(context_length,?hidden_size),?dtype=dtype) ????????#?創建一個名為TimeDecay的匹配緩沖區,通過T.match_buffer方法匹配參數time_decay的形狀和數據類型。 ????????#?這里的TimeDecay就是上面的w ????????TimeDecay?=?T.match_buffer(time_decay,?(hidden_size,),?dtype=dtype) ????????#?創建一個名為TimeFirst的匹配緩沖區,通過T.match_buffer方法匹配參數time_first的形狀和數據類型。 ????????#?這里的TimeFirst對應上面的u ????????TimeFirst?=?T.match_buffer(time_first,?(hidden_size,),?dtype=dtype) ????????#?對應kernel里面的_aa的上一個token的狀態 ????????SavedA?=?T.match_buffer(saved_a,?(1,?hidden_size),?dtype=dtype) ????????#?對應kernel里面的_bb的上一個token的狀態 ????????SavedB?=?T.match_buffer(saved_b,?(1,?hidden_size),?dtype=dtype) ????????#?對應kernel里面的_pp的上一個token的狀態 ????????SavedP?=?T.match_buffer(saved_p,?(1,?hidden_size),?dtype=dtype) ????????#?對應_aa的當前token狀態 ????????OutA?=?T.match_buffer(out_a,?(1,?hidden_size),?dtype=dtype) ????????#?對應_bb的當前token狀態 ????????OutB?=?T.match_buffer(out_b,?(1,?hidden_size),?dtype=dtype) ????????#?對應_pp的當前token狀態 ????????OutP?=?T.match_buffer(out_p,?(1,?hidden_size),?dtype=dtype) ????????#?對應kernel里面的p ????????P?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對應kernel里面的e1 ????????E1?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對應kernel里面的e2 ????????E2?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對應kernel里面的aa ????????A_local?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對應kernel里面的bb ????????B_local?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對應kernel里面的cc ????????P_local?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?迭代hidden_size?//?32次,使用T.thread_binding方法進行線程綁定,其中hidden_size?//?32是塊索引的范圍。 ????????#?這里的線程塊劃分和rwkv?kernel里面保持一致:即每個block?32個線程,一共((B=1)*C)/32個blcok ????????for?bx?in?T.thread_binding(hidden_size?//?32,?thread="blockIdx.x"): ????????????#?迭代32次,使用T.thread_binding方法進行線程綁定,其中32是線程索引的范圍。 ????????????for?tx?in?T.thread_binding(32,?thread="threadIdx.x"): ????????????????#?創建一個名為"init"的塊,用于初始化局部變量。 ????????????????with?T.block("init"): ????????????????????#?對應?const?int?_state_offset?=?_b?*?C?+?_c; ????????????????????vi?=?T.axis.S(hidden_size,?bx?*?32?+?tx) ????????????????????#?對應?float?aa?=?_aa[_state_offset]; ????????????????????A_local[vi]?=?SavedA[0,?vi] ????????????????????#?對應?float?bb?=?_bb[_state_offset]; ????????????????????B_local[vi]?=?SavedB[0,?vi] ????????????????????#?對應?float?pp?=?_pp[_state_offset]; ????????????????????P_local[vi]?=?SavedP[0,?vi] ????????????????for?j?in?range(context_length):?#?對應?for?(int?i?=?0;?i?我們可以看到mlc-llm里面的wkv forward實現基本就是用基于Relax的api將cuda函數翻譯成了TIR。注釋里面給了一些下標的推導以及每一行Relax的代碼是如何對應到原始的cuda kernel。
#?定義了一個名為_te_concat_saved_x的函數,它接受兩個參數saved_x和x,都是te.Tensor類型的張量。 #?使用TVM的te.compute函數計算一個新的張量,該張量的形狀與x相同,元素根據條件判斷進行選擇。如果i等于0, #?則選擇saved_x[0,?j]作為元素值,否則選擇x[i?-?1,?j]作為元素值。其中i和j是迭代變量。 def?_te_concat_saved_x(saved_x:?te.Tensor,?x:?te.Tensor): ????return?te.compute( ????????x.shape, ????????lambda?i,?j:?tir.if_then_else(i?==?0,?saved_x[0,?j],?x[i?-?1,?j]), ????) #?定義了一個名為_te_get_last_x的函數,它接受一個參數x,是一個te.Tensor類型的張量。 #?a.?seq_len,?hidden_size?=?x.shape:獲取x張量的形狀,其中seq_len表示序列長度,hidden_size表示隱藏大小。 #?b.?return?te.compute(...):使用TVM的te.compute函數計算一個新的張量,該張量的形狀為(1,?hidden_size), #?元素值為x[seq_len?-?1,?j],其中j是迭代變量。 def?_te_get_last_x(x:?te.Tensor): ????seq_len,?hidden_size?=?x.shape ????return?te.compute((1,?hidden_size),?lambda?_,?j:?x[seq_len?-?1,?j])這兩個函數應該對應了 https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L455 這里代碼里面的sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))和xx[-1, :]:
@MyFunction ????def?ffn_seq(self,?x,?sx,?ln_w,?ln_b,?k_mix,?r_mix,?kw,?vw,?rw,?kmx,?krx,?kmy,?kry,?vmx,?vrx,?vmy,?vry,?rmx,?rrx,?rmy,?rry): ????????xx?=?F.layer_norm(x,?(x.shape[-1],),?weight=ln_w,?bias=ln_b) ????????sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:])) ????????kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix) ????????rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????r?=?torch.sigmoid(gemm(rx,?rw)) ????????vx?=?torch.square(torch.relu(gemm(kx,?kw))) ????????out?=?r?*?gemm(vx,?vw) ????????return?x?+?out,?xx[-1,:]接著對Embedding函數進行解析:
#?定義了一個名為RWKV_Embedding的PyTorch模塊。 class?RWKV_Embedding(nn.Module): ????#?定義了RWKV_Embedding類的構造函數,接受三個參數num_embeddings、embedding_dim和dtype。 ????def?__init__(self,?num_embeddings,?embedding_dim,?dtype): ????????self.num_embeddings?=?num_embeddings?#?將num_embeddings賦值給類成員變量self.num_embeddings。 ????????self.embedding_dim?=?embedding_dim?#?將embedding_dim賦值給類成員變量self.embedding_dim。 ????????#?創建一個名為weight的Parameter,形狀為(num_embeddings,?embedding_dim), ????????#?數據類型為dtype,并將其賦值給類成員變量self.weight。 ????????self.weight?=?nn.Parameter( ????????????(num_embeddings,?embedding_dim),?dtype=dtype,?name="weight" ????????) ????def?forward(self,?x:?relax.Expr)?->?relax.Var: ????????#?調用op.reshape函數將輸入張量x進行reshape,將其展平為一維張量,并將結果重新賦值給x。 ????????#?nn.emit是將一個relax.Expr表達式轉化為relax.Var變量,并保存該變量。 ????????x?=?nn.emit(op.reshape(x,?shape=[-1])) ????????#?使用op.take操作從self.weight中按照索引x提取對應的嵌入向量,并返回結果。這里的axis=0表示在第一個維度上進行索引操作。 ????????return?nn.emit(op.take(self.weight,?x,?axis=0))以及LayerNorm:
#?這段代碼定義了一個名為RWKV_LayerNorm的PyTorch模塊,它實現了一個Layer?Normalization層。 class?RWKV_LayerNorm(nn.Module): ????#?定義了RWKV_LayerNorm類的構造函數,接受四個參數intermediate_size、dtype、eps和name_prefix。 ????def?__init__(self,?intermediate_size,?dtype,?eps=1e-5,?name_prefix=""): ????????super().__init__() ????????self.eps?=?eps ????????self.weight?=?nn.Parameter( ????????????(intermediate_size,),?dtype=dtype,?name=f"{name_prefix}_ln_weight" ????????) ????????self.bias?=?nn.Parameter( ????????????(intermediate_size,),?dtype=dtype,?name=f"{name_prefix}_ln_bias" ????????) ????def?forward(self,?x:?relax.Expr)?->?relax.Var: ????????#?使用op.nn.layer_norm操作對輸入張量x進行Layer?Normalization,其中使用Parameter?self.weight作為縮放參數(gamma), ????????#?使用可學習參數self.bias作為偏移參數(beta),在最后一個維度(axes=-1)上進行標準化操作, ????????#?并設置小數值修正項為self.eps。將標準化后的結果重新賦值給x。 ????????x?=?nn.emit( ????????????op.nn.layer_norm( ????????????????x, ????????????????gamma=self.weight, ????????????????beta=self.bias, ????????????????axes=-1, ????????????????epsilon=self.eps, ????????????) ????????) ????????return?x接著對FFN層做一個詳細的解析:
#?這段代碼定義了一個名為RWKV_FFN的PyTorch模塊,它實現了Feed-Forward?Network(FFN)。 class?RWKV_FFN(nn.Module): ????#?定義了RWKV_FFN類的構造函數,接受兩個參數RWKVConfig和index。 ????def?__init__(self,?config:?RWKVConfig,?index:?int)?->?None: ????????super().__init__() ????????#?將config.hidden_size賦值給類成員變量self.hidden_size,表示隱藏大小。 ????????self.hidden_size?=?config.hidden_size ????????#?將config.dtype賦值給類成員變量self.dtype,表示數據類型。 ????????self.dtype?=?config.dtype ????????#?將index賦值給類成員變 ????????self.index?=?index ????????#?建一個名為time_mix_key的可學習參數,形狀為(self.hidden_size,), ????????#?數據類型為config.dtype,命名為"ffn_{index}_time_mix_k",并將其賦值給類成員變量self.time_mix_key。 ????????self.time_mix_key?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"ffn_{index}_time_mix_k" ????????) ????????#?創建一個名為time_mix_receptance的可學習參數,形狀為(self.hidden_size,),數據類型為config.dtype, ????????#?命名為"ffn_{index}_time_mix_r",并將其賦值給類成員變量self.time_mix_receptance。 ????????self.time_mix_receptance?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"ffn_{index}_time_mix_r" ????????) ????????#?創建一個線性層,輸入大小為self.hidden_size,輸出大小為config.intermediate_size, ????????#?數據類型為config.dtype,沒有偏置項,并將其賦值給類成員變量self.key。 ????????self.key?=?Linear( ????????????self.hidden_size,?config.intermediate_size,?dtype=config.dtype,?bias=False ????????) ????????#?創建一個線性層,輸入大小為self.hidden_size,輸出大小為self.hidden_size,數據類型為config.dtype, ????????#?沒有偏置項,并將其賦值給類成員變量self.receptance。 ????????self.receptance?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.value?=?Linear( ????????????config.intermediate_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????def?forward(self,?x:?Expr,?state:?Expr)?->?Expr: ????????#?計算偏移量,用于在state中獲取對應的保存狀態。 ????????offset?=?self.index?*?5?+?State.FFN_X ????????#?獲取x的shape[0]表示上下文長度。 ????????context_length?=?x.struct_info.shape[0] ????????#?獲取隱藏層大小。 ????????hidden_size?=?self.hidden_size ????????#?調用_load_state函數從state中加載保存的狀態state[offset],并將結果賦值給saved_x。 ????????saved_x?=?_load_state(state[offset],?hidden_size,?self.dtype) ????????#?如果上下文長度不為1,則執行下面的操作。 ????????if?not?is_one(context_length): ????????????#?調用nn.emit_te函數,將saved_x和x作為參數傳遞給 ????????????#?_te_concat_saved_x函數進行計算,并將結果重新賦值給saved_x。 ????????????#?類似于transformer?里面的KV?Cache的,但是這里的concat是緯度不變的 ????????????#?對應?sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:]))?這行代碼 ????????????saved_x?=?nn.emit_te(_te_concat_saved_x,?saved_x,?x) ????????#?創建一個全為1的張量,形狀為(hidden_size,),數據類型為self.dtype,并將其賦值給ones。 ????????ones?=?nn.emit(relax.op.ones((hidden_size,),?self.dtype)) ????????#?計算xk,根據時間混合參數self.time_mix_key和保存的狀態saved_x,使用加權求和的方式得到。 ????????#?其中,x和saved_x分別乘以self.time_mix_key和(ones?-?self.time_mix_key),然后相加。將計算結果賦值給xk。 ????????#?對應?kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix)?這行代碼 ????????xk?=?nn.emit(x?*?self.time_mix_key?+?saved_x?*?(ones?-?self.time_mix_key)) ????????#?計算xr,根據時間混合參數self.time_mix_receptance和保存的狀態saved_x,使用加權求和的方式得到。 ????????#?其中,x和saved_x分別乘以self.time_mix_receptance和(ones?-?self.time_mix_receptance),然后相加。 ????????#?將計算結果賦值給xr。 ????????#?對應?rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????xr?=?nn.emit( ????????????x?*?self.time_mix_receptance?+?saved_x?*?(ones?-?self.time_mix_receptance) ????????) ????????#?#?如果上下文長度不為1,則執行下面的操作。 ????????if?not?is_one(context_length): ????????????#?調用nn.emit_te函數,使用_te_get_last_x函數從x中獲取最后一個token對應的tensor,并將結果重新賦值給x。 ????????????#?對應?xx[-1,:] ????????????x?=?nn.emit_te(_te_get_last_x,?x) ????????#?斷言x的結構信息(shape)的第一個維度為1。 ????????assert?is_one(x.struct_info.shape[0]) ????????#?調用_store_state函數,將x保存到state[offset]中,并將結果重新賦值給saved_x。 ????????#?對應:https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L921 ????????saved_x?=?_store_state(state[offset],?x) ????????#?將xr作為輸入,經過sigmoid激活函數計算得到r。對應:r?=?torch.sigmoid(gemm(rx,?rw)) ????????r?=?nn.emit(op.sigmoid(self.receptance(xr))) ????????#?對應?vx?=?torch.square(torch.relu(gemm(kx,?kw))) ????????xv?=?nn.emit(op.square(op.nn.relu(self.key(xk)))) ????????return?nn.emit(r?*?self.value(xv)),?[saved_x]接下來對Attention部分的實現進行解析,注意這部分對應的代碼在 https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L728-L747 。貼一下python代碼防止看錯位置產生疑問:
if?os.environ["RWKV_CUDA_ON"]?==?'1': ????????@MyFunction ????????def?cuda_att_seq(self,?x,?sx,?aa,?bb,?pp,?ln_w,?ln_b,?k_mix,?v_mix,?r_mix,?t_decay,?t_first,?kw,?vw,?rw,?ow,?kmx,?krx,?kmy,?kry,?vmx,?vrx,?vmy,?vry,?rmx,?rrx,?rmy,?rry,?omx,?orx,?omy,?ory): ????????????T,?C?=?x.shape ????????????xx?=?F.layer_norm(x,?(C,),?weight=ln_w,?bias=ln_b) ????????????sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:])) ????????????kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix) ????????????vx?=?xx?*?v_mix?+?sx?*?(1?-?v_mix) ????????????rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????????r?=?torch.sigmoid(gemm(rx,?rw)) ????????????k?=?gemm(kx,?kw,?output_dtype=torch.float32) ????????????v?=?gemm(vx,?vw,?output_dtype=torch.float32) ????????????y,?aa,?bb,?pp?=?cuda_wkv(T,?aa.shape[0],?t_decay,?t_first,?k,?v,?aa,?bb,?pp) ???????????? ????????????out?=?gemm(r?*?y.to(x.dtype),?ow) ????????????return?x?+?out,?xx[-1,:],?aa,?bb,?pp對應mlc-llm RWKV Attention的代碼解析為:
#?實現RWKV?Attention,對應?https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L479 class?RWKV_Attention(nn.Module): ????#?初始化函數,接受一個config對象和一個整數index作為參數。其中config是一個RWKVConfig類型的對象,index表示當前層的索引。 ????def?__init__(self,?config:?RWKVConfig,?index:?int)?->?None: ????????super().__init__() ????????self.index?=?index ????????self.dtype?=?config.dtype ????????self.hidden_size?=?config.hidden_size ????????#?創建一些可學習的參數,如time_decay、time_first、time_mix_key等,這些參數會在模型的前向傳播中使用。 ????????self.time_decay?=?nn.Parameter( ????????????(self.hidden_size,),?dtype="float32",?name=f"att_{index}_time_decay" ????????) ????????self.time_first?=?nn.Parameter( ????????????(self.hidden_size,),?dtype="float32",?name=f"att_{index}_time_first" ????????) ????????self.time_mix_key?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"att_{index}_time_mix_k" ????????) ????????self.time_mix_value?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"att_{index}_time_mix_v" ????????) ????????self.time_mix_receptance?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"att_{index}_time_mix_r" ????????) ????????#?前向傳播用到的線性層 ????????self.key?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.value?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.receptance?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.output?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????#?前向傳播函數,接受輸入張量x和狀態張量state作為參數,并返回輸出張量 ????def?forward(self,?x:?Expr,?state:?Expr)?->?Expr: ????????#?Load?current?state ????????#?定義了一些局部變量,如ones、index、hidden_size、context_length等。 ????????ones?=?nn.emit(relax.op.ones((self.hidden_size,),?self.dtype)) ????????index?=?self.index ????????hidden_size?=?self.hidden_size ????????context_length?=?x.struct_info.shape[0] ????????bb?=?relax.BlockBuilder.current() ????????#?_load_state函數從state中加載保存的狀態,賦值給saved_a、saved_b、saved_p和saved_x。 ????????saved_a?=?_load_state(state[index?*?5?+?State.ATT_A],?hidden_size,?"float32") ????????saved_b?=?_load_state(state[index?*?5?+?State.ATT_B],?hidden_size,?"float32") ????????saved_p?=?_load_state(state[index?*?5?+?State.ATT_P],?hidden_size,?"float32") ????????saved_x?=?_load_state(state[index?*?5?+?State.ATT_X],?hidden_size,?self.dtype) ???????? ????????#?調用nn.emit_te函數,將saved_x和x作為參數傳遞給 ????????#?_te_concat_saved_x函數進行計算,并將結果重新賦值給saved_x。 ????????#?對應?sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:])) ????????if?not?is_one(context_length): ????????????saved_x?=?nn.emit_te(_te_concat_saved_x,?saved_x,?x) ????????#?對應?kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix) ????????xk?=?nn.emit(x?*?self.time_mix_key?+?saved_x?*?(ones?-?self.time_mix_key)) ????????#?對應?vx?=?xx?*?v_mix?+?sx?*?(1?-?v_mix) ????????xv?=?nn.emit(x?*?self.time_mix_value?+?saved_x?*?(ones?-?self.time_mix_value)) ????????#?對應?rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????xr?=?nn.emit( ????????????x?*?self.time_mix_receptance?+?saved_x?*?(ones?-?self.time_mix_receptance) ????????) ????????#?對應?r?=?torch.sigmoid(gemm(rx,?rw)) ????????r?=?nn.emit(op.sigmoid(self.receptance(xr))) ????????#?對應?k?=?gemm(kx,?kw,?output_dtype=torch.float32) ????????k?=?nn.emit(op.astype(self.key(xk),?"float32")) ????????#?對應?v?=?gemm(vx,?vw,?output_dtype=torch.float32) ????????v?=?nn.emit(op.astype(self.value(xv),?"float32")) ????????#?這部分對應?y,?aa,?bb,?pp?=?cuda_wkv(T,?aa.shape[0],?t_decay,?t_first,?k,?v,?aa,?bb,?pp) ????????#?這里的?create_wkv_func?在上面已經解析了 ????????gv?=?bb.add_func(create_wkv_func(hidden_size,?"float32",?self.dtype),?"wkv") ????????ret?=?nn.emit( ????????????relax.call_tir( ????????????????gv, ????????????????[k,?v,?self.time_decay,?self.time_first,?saved_a,?saved_b,?saved_p], ????????????????[ ????????????????????R.Tensor((context_length,?hidden_size),?self.dtype),?#?對應wkv ????????????????????R.Tensor((1,?hidden_size),?"float32"),?#?對應out_a ????????????????????R.Tensor((1,?hidden_size),?"float32"),?#?對應out_b ????????????????????R.Tensor((1,?hidden_size),?"float32"),?#?對應out_p ????????????????], ????????????) ????????) ????????if?not?is_one(context_length): ????????????#?對應?xx[-1,:] ????????????x?=?nn.emit_te(_te_get_last_x,?x) ????????assert?is_one(x.struct_info.shape[0]) ????????saved_x?=?_store_state(state[self.index?*?5?+?State.ATT_X],?x) ????????saved_a?=?_store_state(state[self.index?*?5?+?State.ATT_A],?ret[1]) ????????saved_b?=?_store_state(state[self.index?*?5?+?State.ATT_B],?ret[2]) ????????saved_p?=?_store_state(state[self.index?*?5?+?State.ATT_P],?ret[3]) ????????#?需要注意一下,python代碼里面的?return?x?+?out,?xx[-1,:],?aa,?bb,?pp ????????#?這里的?x?+?out被放在attention外面做了,因為這里的x已經是被修改之后好的結果而不是原始的x ????????return?nn.emit(self.output(r?*?ret[0])),?[ ????????????saved_x, ????????????saved_a, ????????????saved_b, ????????????saved_p, ????????]接著解析一下RWKVLayer的實現,請注意下面的最后一行代碼的解釋:
class?RWKVLayer(nn.Module): ????#?初始化函數,接受一個config對象和一個整數index作為參數。其中config是一個RWKVConfig類型的對象,index表示層的索引。 ????def?__init__(self,?config:?RWKVConfig,?index:?int)?->?None: ????????super().__init__() ????????#?如果index為0,創建一個RWKV_LayerNorm對象pre_ln,用于對輸入進行Layer?Normalization操作。 ????????if?index?==?0: ????????????self.pre_ln?=?RWKV_LayerNorm( ????????????????config.hidden_size, ????????????????config.dtype, ????????????????eps=config.layer_norm_epsilon, ????????????????name_prefix="pre_ln", ????????????) ????????#?創建兩個RWKV_LayerNorm對象,分別命名為ln1和ln2, ????????#?用于對注意力機制和前饋神經網絡的輸出進行Layer?Normalization操作。 ????????self.ln1?=?RWKV_LayerNorm( ????????????config.hidden_size, ????????????config.dtype, ????????????eps=config.layer_norm_epsilon, ????????????name_prefix=f"att_{index}", ????????) ????????self.ln2?=?RWKV_LayerNorm( ????????????config.hidden_size, ????????????config.dtype, ????????????eps=config.layer_norm_epsilon, ????????????name_prefix=f"ffn_{index}", ????????) ????????#?創建一個RWKV_Attention對象attention,用于實現注意力機制。 ????????self.attention?=?RWKV_Attention(config,?index) ????????#?創建一個RWKV_FFN對象feed_forward,用于實現前饋神經網絡。 ????????self.feed_forward?=?RWKV_FFN(config,?index) ????????self.rescale_every?=?config.rescale_every ????????self.dtype?=?config.dtype ????????self.index?=?index ????#?前向傳播函數,接受輸入張量x和狀態張量state作為參數,并返回輸出張量和更新后的狀態列表。 ????def?forward(self,?x:?Expr,?state:?Expr)?->?Tuple[Expr,?List[Expr]]: ????????#?如果index為0,則將輸入張量x傳入pre_ln進行Layer?Normalization操作。 ????????if?self.index?==?0: ????????????x?=?self.pre_ln(x) ????????#?將經過ln1的輸入張量x和狀態張量state傳入attention進行計算,得到注意力機制的輸出att和更新后的狀態列表att_state。 ????????att,?att_state?=?self.attention(self.ln1(x),?state) ????????#?將輸入張量x和注意力機制的輸出att相加,并將結果賦值給x。 ????????x?=?nn.emit(x?+?att) ????????#?將經過ln2的輸入張量x和狀態張量state傳入feed_forward進行計算,得到前饋神經網絡的輸出ffn和更新后的狀態列表ffn_state。 ????????ffn,?ffn_state?=?self.feed_forward(self.ln2(x),?state) ????????#?將輸入張量x和前饋神經網絡的輸出ffn相加,并將結果賦值給x。 ????????x?=?nn.emit(x?+?ffn) ????????#?如果滿足self.rescale_every?>?0且(self.index?+?1)?%?self.rescale_every?==?0,則對輸入張量x進行縮放操作。 ????????if?self.rescale_every?>?0?and?(self.index?+?1)?%?self.rescale_every?==?0: ????????????x?=?nn.emit(x?/?relax.const(2,?dtype=self.dtype)) ????????#?返回輸出張量x和注意力機制和前饋神經網絡的更新后的狀態列表的拼接。 ????????return?x,?att_state?+?ffn_state注意這里的attn_state是[saved_x,saved_a,saved_b,saved_p,] ,然后ffn_state是[saved_x],注意這兩個x是不一樣的,這5個狀態也和本節開頭的class State的成員定義一致。
接下來對RWKV模型定義進行了解析:
#?該代碼是一個自定義的PyTorch模型類RWKVModel,繼承自nn.Module class?RWKVModel(nn.Module): ????#?初始化函數,接受一個config對象作為參數。其中config是一個RWKVConfig類型的對象。 ????def?__init__(self,?config:?RWKVConfig)?->?None: ????????super().__init__() ????????#?創建一個RWKV_Embedding對象embeddings,用于實現輸入的嵌入操作。 ????????self.embeddings?=?RWKV_Embedding( ????????????num_embeddings=config.vocab_size, ????????????embedding_dim=config.hidden_size, ????????????dtype=config.dtype, ????????) ????????#?創建一個ModuleList對象blocks,其中包含了config.num_hidden_layers個RWKVLayer對象, ????????#?每個對象的索引從0到config.num_hidden_layers-1。 ????????self.blocks?=?ModuleList( ????????????[RWKVLayer(config,?i)?for?i?in?range(config.num_hidden_layers)] ????????) ????????#?創建一個RWKV_LayerNorm對象ln_out,用于對輸出進行Layer?Normalization操作。 ????????self.ln_out?=?RWKV_LayerNorm( ????????????config.hidden_size, ????????????config.dtype, ????????????eps=config.layer_norm_epsilon, ????????????name_prefix="out_ln", ????????) ????????self.hidden_size?=?config.hidden_size ????????self.dtype?=?config.dtype ????#?前向傳播函數,接受輸入張量input_ids和狀態張量state作為參數,并返回輸出張量和更新后的狀態列表。 ????def?forward(self,?input_ids:?Expr,?state:?Expr)?->?Tuple[Expr,?List[Expr]]: ????????#?將輸入張量input_ids傳入embeddings進行嵌入操作,得到隱藏狀態張量hidden_states。 ????????hidden_states?=?self.embeddings(input_ids) ????????#?創建一個空列表states,用于存儲每個RWKVLayer對象的更新后的狀態列表。 ????????states?=?[] ????????#?遍歷blocks中的每個RWKVLayer對象,將隱藏狀態張量hidden_states和狀態張量state傳入 ????????#?每個RWKVLayer對象的前向傳播函數進行計算,得到更新后的隱藏狀態張量和更新后的狀態列表, ????????#?并將更新后的狀態列表添加到states中。 ????????for?_,?layer?in?enumerate(self.blocks): ????????????hidden_states,?layer_states?=?layer(hidden_states,?state) ????????????states?+=?layer_states ????????#?獲取隱藏狀態張量的上下文長度context_length。 ????????context_length?=?hidden_states.struct_info.shape[0] ????????#?如果context_length不為1,則調用_te_get_last_x函數獲取最后一個token對應的張量。 ????????if?not?is_one(context_length): ????????????hidden_states?=?nn.emit_te(_te_get_last_x,?hidden_states) ????????#?將隱藏狀態張量傳入ln_out進行Layer?Normalization操作。 ????????hidden_states?=?self.ln_out(hidden_states) ????????#?返回輸出隱藏狀態張量和所有RWKVLayer對象的更新后的狀態列表。 ????????return?hidden_states,?states #?該代碼是一個自定義的PyTorch模型類RWKVForCausalLM,繼承自nn.Module。 class?RWKVForCausalLM(nn.Module): ????#?初始化函數,接受一個config對象作為參數。其中config是一個RWKVConfig類型的對象。 ????def?__init__(self,?config:?RWKVConfig): ????????#?創建一個RWKVModel對象rwkv,用于實現序列模型的計算。 ????????self.rwkv?=?RWKVModel(config) ????????#?創建一個Linear對象head,用于將隱藏狀態映射到詞匯表大小的輸出空間。 ????????self.head?=?Linear( ????????????config.hidden_size,?config.vocab_size,?dtype=config.dtype,?bias=False ????????) ????????self.vocab_size?=?config.vocab_size ????????############?End?############ ????#?前向傳播函數,接受輸入張量input_ids和狀態張量state作為參數,并返回預測的logits和更新后的kv?cache。 ????def?forward( ????????self, ????????input_ids:?relax.Expr, ????????state:?relax.Expr, ????): ????????#?將輸入張量input_ids和狀態張量state傳入rwkv對象的前向傳播函數進行計算, ????????#?得到更新后的隱藏狀態張量hidden_states和key-value緩存key_value_cache。 ????????hidden_states,?key_value_cache?=?self.rwkv(input_ids,?state) ????????#?將隱藏狀態張量hidden_states傳入head進行線性映射操作,得到logits。 ????????logits?=?nn.emit(self.head(hidden_states)) ????????#?對logits進行形狀重塑,將其reshape為形狀為(1,?1,?self.vocab_size)的張量。 ????????logits?=?nn.emit(op.reshape(logits,?(1,?1,?self.vocab_size))) ????????#?如果logits的數據類型不是float32,則將其轉換為float32類型。 ????????if?logits.struct_info.dtype?!=?"float32": ????????????logits?=?nn.emit(relax.op.astype(logits,?"float32")) ????????return?logits,?key_value_cache解下是一個根據參數的名字確定量化參數類型的函數:
#?該代碼定義了一個函數get_param_quant_kind,用于根據參數名稱和參數信息確定參數的量化類型。 def?get_param_quant_kind( ????name:?str,?param_info:?relax.TensorStructInfo )?->?ParamQuantKind: ????#?如果參數名稱以"embeddings.weight"結尾,返回ParamQuantKind.embedding_table表示該參數是嵌入表的權重。 ????if?name.endswith("embeddings.weight"): ????????return?ParamQuantKind.embedding_table ????#?如果參數名稱為"head.weight",返回ParamQuantKind.final_fc_weight表示該參數是最后一個全連接層的權重。 ????elif?name?==?"head.weight": ????????return?ParamQuantKind.final_fc_weight ????#?如果參數的維度為2且名稱以".weight"結尾,返回ParamQuantKind.linear_weight表示該參數是線性層的權重。 ????elif?param_info.ndim?==?2?and?name.endswith(".weight"): ????????return?ParamQuantKind.linear_weight ????else: ????????return?ParamQuantKind.others上面已經完成了RWKV模型的定義,接下來是定義幾個相關的TIR函數并定義一個最終的TIR模型獲取函數。這里對創建prefill和decode的create_func函數以及最終的TIR模型獲取函數get_model進行解析:
由于字數被公眾號限制了,請在知乎文章查看這部分,https://zhuanlan.zhihu.com/p/658354795
自此,我們基本就有了搭建RWKV模型的全部流程,說白了就是用TVM的Relax語言手動一對一的把PyTorch實現翻譯過去。
0x3. Transform舉例
在mlc-llm有一些圖層的優化,在 https://github.com/BBuf/mlc-llm-code-analysis/tree/main/mlc_llm/transform 這個文件里面,我們對其中的一些優化Pass做一下解析。
0x3.1 rewrite attention
代碼如下:
#?導入了TVM的relax模塊中的一些函數和類,以及TVM的script模塊中的relax別名。 from?tvm.relax.dpl?import?PatternContext,?is_const,?is_op,?rewrite_call,?wildcard from?tvm.script?import?relax?as?R #?定義了一個名為rewrite_attention的函數,接收一個參數f。 def?rewrite_attention(f): ????#?使用wildcard()創建了三個通配符,分別賦值給Q、K和V。 ????Q?=?wildcard() ????K?=?wildcard() ????V?=?wildcard() ????#?使用is_op()函數創建了三個操作模式,分別對應Q、K和V的維度重排操作,并將結果分別賦值給Q_BNSH、K_BNSH和V_BNSH。 ????Q_BNSH?=?is_op("relax.permute_dims")(Q) ????K_BNSH?=?is_op("relax.permute_dims")(K) ????V_BNSH?=?is_op("relax.permute_dims")(V) ????#?使用is_op()函數創建了一個操作模式,對應K_BNSH的維度重排操作,并將結果賦值給K_BNSH_T。 ????K_BNSH_T?=?is_op("relax.permute_dims")(K_BNSH) ????#?使用is_op()函數創建了一系列操作模式,對應矩陣乘法、除法、最大值、最小值、softmax以及另一個矩陣乘法操作。 ????#?這些操作模式(Attention)根據之前定義的通配符和常數匹配不同的計算圖節點。 ????matmul1?=?is_op("relax.matmul")(Q_BNSH,?K_BNSH_T) ????divide?=?is_op("relax.divide")(matmul1,?is_const()) ????max?=?is_op("relax.maximum")(divide,?is_const()) ????min?=?is_op("relax.minimum")(max,?wildcard()) ????softmax?=?is_op("relax.nn.softmax")(is_op("relax.astype")(min)) ????matmul2?=?is_op("relax.matmul")(is_op("relax.astype")(softmax),?V_BNSH) ????#?使用is_op()函數創建了一個操作模式,對應matmul2的維度重排操作,并將結果賦值給pattern。 ????pattern?=?is_op("relax.permute_dims")(matmul2) ????#?定義了一個名為callback的回調函數,接收兩個參數_和matchings。 ????#?該回調函數使用R.nn.attention函數構建一個新的計算圖節點,并使用matchings字典中的匹配結果來填充該節點的參數。 ????def?callback(_,?matchings): ????????return?R.nn.attention( ????????????matchings[Q],?matchings[K],?matchings[V],?causal_mask="BottomRight" ????????) ????#?使用rewrite_call函數將pattern、callback和輸入的計算圖f傳遞給它,以便在計算圖中應用模式匹配和重寫。 ????#?最后,將重寫后的計算圖返回。 ????return?rewrite_call(pattern,?callback,?f)雖然沒有完全看懂這里的操作比如max和min的含義,但是從后面的callback_可以猜測出這里的Pass就是把打散的Self Attention模塊融合為一個relax.nn.attention操作。在cuda后端,如果支持了cutlass,那么relax.nn.attention操作就對應了Flash Attention。
0x3.2 Transpose MatMul
代碼實現解析如下:
#?這段代碼定義了一個名為TransposeMatmulCodeGenerator的類,該類繼承自relax.PyExprMutator。 #?通過@relax.expr_functor.mutator裝飾器將該類聲明為一個表達式重寫器。 @relax.expr_functor.mutator class?TransposeMatmulCodeGenerator(relax.PyExprMutator): ????def?__init__(self,?mod): ????????super().__init__(mod) ????@staticmethod ????def?pattern(): ????????#?定義了靜態方法pattern(),該方法返回一個描述模式的元組。 ????????#?通過使用通配符(wildcard())和操作模式(is_op())來匹配計算圖中的特定模式。 ????????#?在這個例子中,模式匹配了一個矩陣乘法操作中矩陣w的維度重排操作,并將匹配的結果保存在字典annotations中。 ????????w?=?wildcard() ????????x?=?wildcard() ????????wT?=?is_op("relax.permute_dims")(w) ????????o?=?is_op("relax.matmul")(x,?wT) ????????annotations?=?{"o":?o,?"w":?w,?"x":?x,?"wT":?wT} ????????#?定義了內部函數_check(),用于檢查模式匹配的結果是否滿足特定的條件。 ????????#?在這個例子中,檢查了維度重排操作的維度數和軸的順序是否正確。 ????????def?_check(context:?relax.transform.PatternCheckContext)?->?bool: ????????????transpose_call?=?context.annotated_expr["wT"] ????????????ndim?=?transpose_call.args[0].struct_info.ndim ????????????if?ndim?==?-1: ????????????????return?False ????????????if?ndim?==?2?and?transpose_call.attrs.axes?is?None: ????????????????return?True ????????????axes?=?list(range(ndim)) ????????????axes[-1],?axes[-2]?=?axes[-2],?axes[-1] ????????????return?list(transpose_call.attrs.axes)?==?axes ????????#?將匹配的計算圖節點、注解和檢查函數作為元組返回。 ????????return?o,?annotations,?_check ????#?重寫了父類的visit_call_()方法,用于處理特定類型的計算圖節點。 ????def?visit_call_(self,?call:?relax.Call)?->?relax.Expr: ????????#?定義了一個變量out_dtype,用于保存輸出的數據類型。 ????????out_dtype?=?None ????????#?定義了一個內部函數te_transposed_matmul(),該函數實現了矩陣乘法的計算邏輯。 ????????def?te_transposed_matmul(a:?te.Tensor,?b:?te.Tensor)?->?te.Tensor: ????????????nonlocal?out_dtype ????????????#?將輸入張量?a?和?b?的形狀轉換為列表形式,分別保存在變量?a_shape?和?b_shape?中。 ????????????a_shape?=?list(a.shape) ????????????b_shape?=?list(b.shape) ????????????#?定義了兩個布爾變量?a_prepended?和?b_appended,用于標記是否在相應的形狀的前面或后面添加了維度。 ????????????a_prepended?=?False ????????????b_appended?=?False ????????????#?如果輸入張量?a?的形狀為一維,則在其前面添加一個維度,將其形狀修改為?(1,?original_shape)。 ????????????#?同樣地,如果輸入張量?b?的形狀為一維,則在其后面添加一個維度,將其形狀修改為?(original_shape,?1)。 ????????????if?len(a_shape)?==?1: ????????????????a_prepended?=?True ????????????????a_shape.insert(0,?1) ????????????if?len(b_shape)?==?1: ????????????????b_appended?=?True ????????????????b_shape.append(1) ????????????#?比較?a_shape?和?b_shape?的長度,將結果保存在布爾變量?is_a_larger?中。 ????????????#?offset?表示兩個形狀長度之差,用于后續處理。 ????????????is_a_larger?=?len(a_shape)?>?len(b_shape) ????????????offset?=?( ????????????????len(a_shape)?-?len(b_shape) ????????????????if?is_a_larger ????????????????else?len(b_shape)?-?len(a_shape) ????????????) ????????????#?創建兩個?relax.Var?對象?a_relax?和?bT_relax,用于表示張量?a?和轉置后的張量?bT?的結構信息。 ????????????#?a_relax?的形狀和?a?的形狀相同,bT_relax?的形狀是?b?的形狀經過維度互換后的結果。 ????????????a_relax?=?relax.Var("a",?relax.TensorStructInfo(a.shape)) ????????????bT_shape?=?list(b.shape) ????????????bT_shape[-1],?bT_shape[-2]?=?bT_shape[-2],?bT_shape[-1] ????????????bT_relax?=?relax.Var("b",?relax.TensorStructInfo(bT_shape)) ????????????#?使用?relax.op.matmul()?方法對?a_relax?和?bT_relax?進行矩陣乘法運算。 ????????????#?然后,通過?self.builder_.normalize()?方法對結果進行歸一化處理,并獲取最終的輸出形狀。 ????????????output_shape?=?self.builder_.normalize( ????????????????relax.op.matmul(a_relax,?bT_relax) ????????????).struct_info.shape ????????????#?該函數接受可變數量的空間索引參數?idx_spatial, ????????????def?matmul_compute(*idx_spatial): ????????????????#?并定義了一個名為?k?的規約軸(reduce?axis),其范圍為?0?到?a_shape[-1]。 ????????????????k?=?te.reduce_axis((0,?a_shape[-1]),?name="k") ????????????????#?定義了一個名為?multiply_compute?的內部函數,用于計算乘法操作時的索引。 ????????????????def?multiply_compute(idx_reduce): ????????????????????a_indices?=?[] ????????????????????b_indices?=?[] ????????????????????#?根據?is_a_larger?的值,將?idx_spatial?中的索引分配給?a_indices?或?b_indices,用于處理形狀長度差異的維度。 ????????????????????for?i?in?range(offset): ????????????????????????if?is_a_larger: ????????????????????????????a_indices.append(idx_spatial[i]) ????????????????????????else: ????????????????????????????b_indices.append(idx_spatial[i]) ????????????????????for?i?in?range( ????????????????????????offset,?len(output_shape)?-?(2?-?a_prepended?-?b_appended) ????????????????????): ????????????????????????#?根據維度的相等性,將適當的索引添加到?a_indices?和?b_indices?中。 ????????????????????????#?如果維度不相等或無法確定是否相等,則將索引設為?0?或保持不變。 ????????????????????????a_dim?=?a_shape[i?if?is_a_larger?else?i?-?offset] ????????????????????????b_dim?=?b_shape[i?if?not?is_a_larger?else?i?-?offset] ????????????????????????dim_equal?=?a_dim?==?b_dim ????????????????????????if?not?isinstance(dim_equal,?tir.IntImm)?or?dim_equal?==?0: ????????????????????????????a_dim_is_one?=?isinstance(a_dim,?tir.IntImm)?and?a_dim?==?1 ????????????????????????????b_dim_is_one?=?isinstance(b_dim,?tir.IntImm)?and?b_dim?==?1 ????????????????????????????a_indices.append(0?if?a_dim_is_one?else?idx_spatial[i]) ????????????????????????????b_indices.append(0?if?b_dim_is_one?else?idx_spatial[i]) ????????????????????????else: ????????????????????????????a_indices.append(idx_spatial[i]) ????????????????????????????b_indices.append(idx_spatial[i]) ????????????????????#?在乘法操作的索引中添加規約軸?idx_reduce,并根據?a_prepended?和?b_appended?的值, ????????????????????#?將適當的索引添加到?a_indices?和?b_indices?中。 ????????????????????if?not?a_prepended: ????????????????????????a_indices.append(idx_spatial[-2?+?b_appended]) ????????????????????a_indices.append(idx_reduce) ????????????????????if?not?b_appended: ????????????????????????b_indices.append(idx_spatial[-1]) ????????????????????b_indices.append(idx_reduce) ????????????????????#?根據?out_dtype?的值,選擇是否進行數據類型轉換,并返回乘法操作的結果。 ????????????????????dtype?=?out_dtype ????????????????????if?dtype?!=?"": ????????????????????????return?a(*a_indices).astype(dtype)?*?b(*b_indices).astype(dtype) ????????????????????return?a(*a_indices)?*?b(*b_indices) ????????????????#?在縮減軸?k?上對?multiply_compute?的結果進行求和操作。 ????????????????return?te.sum(multiply_compute(k),?axis=k) ????????????#?使用?te.compute()?函數計算最終的輸出,其中使用一個?lambda?函數將輸入索引傳遞給?matmul_compute?函數, ????????????#?并將結果命名為?"NT_matmul"。整個計算過程將根據?output_shape?進行執行。 ????????????return?te.compute( ????????????????output_shape, ????????????????lambda?*idx:?matmul_compute(*idx),??#?pylint:?disable=unnecessary-lambda ????????????????name="NT_matmul", ????????????) ????????#?首先,檢查函數調用的操作符?call.op?是否是?relax.GlobalVar?類型。如果是,獲取與該操作符對應的函數對象, ????????#?并檢查函數的屬性中是否包含鍵?"Composite",且其值為?"transpose_matmul_fuse"。 ????????if?isinstance(call.op,?relax.GlobalVar): ????????????function?=?self.builder_.get()[call.op] ????????????if?( ????????????????"Composite"?in?function.attrs ????????????????and?function.attrs["Composite"]?==?"transpose_matmul_fuse" ????????????): ????????????????#?將函數的返回類型?function.ret_struct_info.dtype?賦值給變量?out_dtype ????????????????out_dtype?=?function.ret_struct_info.dtype ????????????????#?然后調用?self.builder_.call_te()?方法,傳遞?te_transposed_matmul?函數作為參數, ????????????????#?以及調用的參數?call.args[1]?和?call.args[0],并指定?primfunc_name_hint?為?"NT_matmul"。 ????????????????return?self.builder_.call_te( ????????????????????te_transposed_matmul, ????????????????????call.args[1], ????????????????????call.args[0], ????????????????????primfunc_name_hint="NT_matmul", ????????????????) ????????return?super().visit_call_(call) #?使用?@tvm.transform.module_pass?裝飾器定義了一個名為?FuseTransposeMatmul?的類, #?并指定了優化級別?opt_level=0?和?pass?的名稱為?"FuseTransposeMatmul"。 @tvm.transform.module_pass(opt_level=0,?name="FuseTransposeMatmul") class?FuseTransposeMatmul: ????#?定義了?transform_module?方法,接受一個名為?mod?的?IRModule?對象和 ????#?tvm.transform.PassContext?對象作為參數,并返回一個?IRModule?對象。 ????def?transform_module( ????????self,?mod:?IRModule,?ctx:?tvm.transform.PassContext ????)?->?IRModule: ????????#?通過調用?relax.transform.FuseOpsByPattern?并傳遞一個包含單個模式元組的列表, ????????#?對模塊?mod?進行融合的轉置矩陣乘法操作。 ????????mod?=?relax.transform.FuseOpsByPattern( ????????????[("transpose_matmul_fuse",?*TransposeMatmulCodeGenerator.pattern())] ????????)(mod) ????????#?創建一個名為?transpose_matmul_codegen?的?TransposeMatmulCodeGenerator?對象, ????????#?并對模塊中的每個函數進行遍歷。如果函數是?relax.Function?類型,則調用?transpose_matmul_codegen.visit_expr? ????????#?方法對函數進行轉置矩陣乘法代碼生成,并通過?transpose_matmul_codegen.builder_.update_func?方法更新函數。 ????????transpose_matmul_codegen?=?TransposeMatmulCodeGenerator(mod) ????????for?gv?in?mod.functions: ????????????func?=?mod[gv] ????????????if?not?isinstance(func,?relax.Function): ????????????????continue ????????????func?=?transpose_matmul_codegen.visit_expr(func) ????????????transpose_matmul_codegen.builder_.update_func(gv,?func) ????????#?返回轉置矩陣乘法代碼生成器的?builder?對象中的模塊。 ????????return?transpose_matmul_codegen.builder_.get()?
?
這個Pass將Transpose算子和一個MatMul算子替換為一個TE表達式的實現來達到融合算子的目的。
除了上面2種Pass,MLC-LLM還有不少的圖變換Pass,這篇文章就不一一去解析了,大多數優化的目的都是匹配某種Pattern然后用更優秀的算子去完成計算。
量化策略這一塊就不在這篇文章解析了。
0x4. MLC-LLM優缺點個人評價和期待
0x4.1 優點
Tune Free。mlc-llm不需要用TVM的AutoTVM/Ansor等等程序去執行算子搜索過程,對跨平臺部署是比原始的TVM搭建的模型更清真的。
TIR的語法很大程度靠近了PyTorch的API,使得用戶在模型搭建部分不會很困難。
文檔寫得不錯,跟隨教程基本可以完成大多數平臺的模型部署,并且單Batch下的吞吐和延遲表現都是不錯的。
0x4.2 缺點
不支持從onnx或者huggingface模型直接轉換出TIR,手工實現模型的時候需要相當多的先驗知識,比如在上面的RWKV模型中如果有自定義的cuda kernel,那么這個模型的實現可能只能全權委托給mlc-ai社區的核心開發人員了。
KV Cache開的是max_sequence_length這么長,顯然會有顯存的浪費,Serving的時候極限情況下可以服務的用戶數量應該比VLLM/TGI等要小?
CUDA后端Decoding的Attention我看起來好像還是會用Flash Attention?也許是我看錯了,這條暫時存疑。
在RWKV模型實現里,看到Batch維度寫死為1了,應該不支持動態Batch?這樣對于啟真實服務來說會有一些限制。
0x4.3 期待
如果短期內能讓一個對TVM只有輕度依賴的社區開發者新增一個新的模型。
如果模型存在自定義CUDA Kernel,需要一個詳細的教程來指引。
模型逐層打印來debug精度缺一個教程。
Paged Attention類似策略的引入。
動態Batch的支持。
暫時就想到這些,歡迎斧正。
編輯:黃飛
?
評論
查看更多