社区贡献 | Mixtral-8x7B Pytorch 实现
0.前言
本文从代码角度来谈下 Mixtral 8x7B
混合专家Pytorch
的实现
1.论文概述
Mixtral-8x7B
引爆了MoE
的技术方向,更多针对MoE
优化的Trick
出现,回归模型本身来解析:
-
Mixtral 8x7B
采用了sMoE
模型结构,模型的细节如何?路由负载均衡如何计算?代码如何实现? -
Mixtral 8x7B
的训练流程和推理流程是怎么样的,如何提高训练和推理效率? -
Mixtral 8x7B
的模型参数是如何计算的? -
Mixtral 8x7B
性能硬刚LLaMA2-70B
和GPT-3.5
, 性能一线水准,在MBPP
代码能力超越3.5
2. Mixtral 8x7B 模型架构和计算流程
Mixtral is based on a transformer architecture [31] and uses the same modifications as described in [18], with the notable exceptions that Mixtral supports a fully dense context length of 32k tokens, and the feed forward blocks are replaced by Mixture-of-Expert layers (Section 2.1). The model architecture parameters are summarized in Table 1.
-
base
的模型结构为Transformers
的改版Mistral-7B
-
MoE
作用在Feed Forward Blocks
上
2.1 Mixtral 模型架构
In a Transformer model, the MoE layer is applied independently per token and replaces the feed-forward (FFN) sub-block of the transformer block. For Mixtral we use the same SwiGLU architecture as the expert function Ei(x) and set K = 2. This means each token is routed to two SwiGLU sub-blocks with different sets of weights. Taking this all together, the output y for an input token x is computed as:
-
以 LLaMA2
或Mistral-7B
来说其MLP
都是SwiGLU
形式 -
在 Mixtral-8x7B
中 每层的Decoder
层的MLP
都以sMoE
来替换掉
Transformers Mixtral-of-Expert
代码实现:
在Huggingface
的Transformers
框架中, Mixtral
主要有两部分组成
-
MixtralDecoderLayer
-
MixtralSparseMoeBlock
:替换掉原有的MLP层
MixtralForCausalLM(
(model): MixtralModel(
(embed_tokens): Embedding(32000, 128)
(layers): ModuleList(
(1): MixtralDecoderLayer(
(self_attn): MixtralAttention(
(q_proj): Linear(in_features=128, out_features=128, bias=False)
(k_proj): Linear(in_features=128, out_features=128, bias=False)
(v_proj): Linear(in_features=128, out_features=128, bias=False)
(o_proj): Linear(in_features=128, out_features=128, bias=False)
(rotary_emb): MixtralRotaryEmbedding()
)
(block_sparse_moe): MixtralSparseMoeBlock(
(gate): Linear(in_features=128, out_features=8, bias=False)
(experts): ModuleList(
(0-7): 8 x MixtralBLockSparseTop2MLP(
(w1): Linear(in_features=128, out_features=256, bias=False)
(w2): Linear(in_features=256, out_features=128, bias=False)
(w3): Linear(in_features=128, out_features=256, bias=False)
(act_fn): SiLU()
)
)
)
(input_layernorm): MixtralRMSNorm()
(post_attention_layernorm): MixtralRMSNorm()
)
)
(norm): MixtralRMSNorm()
)
2.2 SMoE 层实现
2.2.1 单个 Expert 实现
import torch
from torch import nn
from transformers import MixtralConfig
class MixtralBLockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = nn.SiLU()
# Forward 是 SwiGLU
def forward(self, hidden_states):
y = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
y = self.w2(y)
return y
x = torch.randn(1, 64, 128)
expert = MixtralBLockSparseTop2MLP(config)
print('单个专家为原LLaMA的MLP层')
print(expert)
g = expert(x)
print('单个专家输入:', x.shape)
print('单个专家输出结果:', g.shape)
结果
单个专家为原LLaMA的MLP层
MixtralBLockSparseTop2MLP(
(w1): Linear(in_features=128, out_features=256, bias=False)
(w2): Linear(in_features=256, out_features=128, bias=False)
(w3): Linear(in_features=128, out_features=256, bias=False)
(act_fn): SiLU()
)
单个专家输入:
torch.Size([1, 64, 128])
单个专家输出结果:
torch.Size([1, 64, 128])
2.2.2 混合Expert实现
class MixtralSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
# 多个 SwiGLU MLP 层组成混合专家
self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) \
for _ in range(self.num_experts)])
x = torch.randn(1, 64, 128)
experts = MixtralSparseMoeBlock(config)
print('多个专家混合专家')
print(experts)
在以上我们实现了模型的关键结构, 但是这里的sMoE
的Forward
并没有实现
2.3 SMoE 计算流程
2.3.1 Gating流程
以下表示为多个token
的gating
计算流程
# 阶段一
# 计算稀疏 gating 值
tokens = 6
x = torch.randn(1, tokens, 128) # 6个token
hidden_states = x
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# 每层都会产生router_logits, 将用于最后作 load balance loss
router_logits = experts.gate(hidden_states)
print(f'experts.gate output router logits : \n {router_logits}')
# 计算 TopK 的 专家 logits 和 Top2 专家的位置
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
print(f'softmax weight : \n {routing_weights}')
routing_weights, selected_experts = torch.topk(routing_weights, \
experts.top_k, dim=-1)
print(f'expert select : \n {selected_experts}')
print(f'topk : \n {routing_weights}')
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
print(f'topk归一化 : \n {routing_weights}')
routing_weights = routing_weights.to(hidden_states.dtype)
## One Hot 编码
expert_mask = torch.nn.functional.one_hot(selected_experts, \
num_classes=experts.num_experts).permute(2, 1, 0)
for i in range(tokens):
print(f'【token_{i}】\n', expert_mask[:,:,i])
追踪x3
的结果
2.3.2 Expert 流程
-
sMoE
中是基于专家来选择token
来计算的 -
token
先序:左图为token3
选择expert 2
,expert 3
号来计算sMoE
结果 -
expert
先序:右图为依次计算expert2
和expert3
才得出token3
的sMoE
结果
代码实现结果为:
## 最终结果
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), \
dtype=hidden_states.dtype, device=hidden_states.device
)
print(f'final moe result shape for each token: {final_hidden_states.shape}')
# 每个专家收集需要计算token
for expert_idx in range(experts.num_experts):
print(f'--------expert {expert_idx} ---------')
expert_layer = experts.experts[expert_idx]
print(expert_mask[expert_idx])
idx, top_x = torch.where(expert_mask[expert_idx])
print(f'专家 {expert_idx} 计算的样本编号:',top_x.tolist()) # select x_idx for expert top1
print(f'专家 {expert_idx} top1:0, top2:1 ',idx.tolist()) # 0 is top1 ,1 is top2
print(f'有 {len(top_x)} / {x.shape[1]} token 选到专家 {expert_idx}')
top_x_list = top_x.tolist()
idx_list = idx.tolist()
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
# expert_0(x) * routing_weights
current_hidden_states = expert_layer(current_state) \
* routing_weights[top_x_list, idx_list, None]
# 将计算的单个专家结果填入到结果表里
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
print(current_state.shape)
print(routing_weights[top_x_list, idx_list, None].shape)
print(current_hidden_states.shape)
print(final_hidden_states.shape)
输出结果为: