广哥在硅谷 · 用思考抵达清晰,用行动靠近自由
JMLR 2022 · GOOGLE

Switch Transformers
万亿参数 + Top-1 路由的极简哲学

Scaling to Trillion Parameter Models · Curated Chinese Edition
Fedus · Zoph · Shazeer · Google 中文精校版 8 章 · 约 25 分钟阅读
📋 内容摘要

2021 年,William Fedus、Barret Zoph、Noam Shazeer(Google)发表 Switch Transformer——MoE 谱系第三个里程碑。核心想法只有一个——把 MoE 路由从 top-k 简化到 top-1。Shazeer 2017 说"必须 $k > 1$ 才能有有意义的梯度"——Switch 论文反驳这个假设,证明 k=1 不仅可行而且更好。简化路由 + 选择性 bfloat16 精度 + 更小初始化 + 专家正则化——让稀疏模型规模到万亿参数比 T5-XXL 快 4 倍,在 101 个语言上全面提升。蒸馏把万亿参数稀疏模型压缩 99%到 dense 模型,保留 30% 质量增益

章节目录
  1. 引言 · 第四个扩展轴
    The fourth scaling axis
  2. Switch Routing · 反 Shazeer 2017
    Top-1 instead of top-k
  3. 专家容量 + Token Dropping
    Expert capacity
  4. 可微负载平衡损失
    Differentiable load balance
  5. 选择性 bfloat16 精度
    Selective precision
  6. 扩展性质 · 7 倍预训练加速
    7x pretraining speedup
  7. 蒸馏 · 99% 压缩,30% 增益保留
    Distillation
  8. 万亿参数 · 4 倍快过 T5-XXL
    Trillion-parameter Switch
CHAPTER 01 · INTRODUCTION

引言 · 第四个扩展轴

The fourth scaling axis
📌 本节核心要点

Kaplan 2020 的 scaling law 揭示了三个扩展轴——模型大小、数据集大小、计算预算。Switch Transformer 提出第四个轴——"在保持每样本 FLOPs 不变的同时增加参数数量"。假设是——参数数量本身就是独立的重要扩展维度。Switch 通过"稀疏激活的模型"实现——它高效利用 GPU/TPU 这种为稠密矩阵乘法设计的硬件。分布式训练里——稀疏激活的层把唯一权重分散到不同设备模型权重随设备数增加,但每设备的内存和计算占用保持可管理

Kaplan 2020 4th Scaling Axis TPU Optimized

2021 年 1 月(arXiv 提交日期),William Fedus、Barret Zoph、Noam Shazeer——同一个 Noam Shazeer,2017 年是Sparsely-Gated MoE的一作,也是Transformer的核心作者——在 Google 发表《Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity》。

论文摘要直接——

深度学习里,模型对所有输入复用同样参数

MoE 模型违背这一点——
给每个输入选不同参数——
结果是一个稀疏激活的、参数离谱地多的、计算成本恒定的模型

尽管 MoE 有几个显著成功
广泛采用被复杂性、通信成本、训练不稳定性阻碍

我们引入 Switch Transformer 解决这些。

— FEDUS, ZOPH, SHAZEER, 2021

Switch 的指导原则——"用简单和计算高效的方式最大化 Transformer 参数数"。

Kaplan 2020 的 scaling law 揭示了三个扩展轴——模型大小、数据集大小、计算预算。Switch 提出第四个轴——

我们探索第四个轴——
保持每样本 FLOPs 恒定的同时
增加参数数量

我们的假设是——
参数数量,独立于总计算——
是一个独立的重要扩展维度

— SWITCH TRANSFORMER

Switch 通过稀疏激活模型实现——"高效利用为稠密矩阵乘法设计的 GPU/TPU 硬件"。论文的核心贡献——

· Switch Transformer 架构——简化并改进 MoE
· vs T5 的扩展性质——7x+ 预训练加速,同样 FLOPS/token
· 蒸馏成功——99% 压缩 + 30% 质量增益保留
· 改进训练技巧——选择性 bfloat16 精度、更小初始化、更强专家正则
· 多语言全面提升——101 个语言91% 的语言在 mT5 上有 4x+ 加速
· 万亿参数模型——T5-XXL 强基线上4 倍加速

广哥在硅谷◆ ◆ ◆
CHAPTER 02 · SWITCH ROUTING

Switch Routing · 反 Shazeer 2017

Top-1 instead of top-k
📌 本节核心要点

Shazeer 2017 主张必须 $k > 1$——"学习路由如果没有比较至少两个专家的能力就不能工作"。Ramachandran-Le 2018 进一步研究 top-k——发现底层用更大 k 重要。Switch 反驳——"我们用一个简化策略,只路由到一个专家"。三大好处——(1) 路由计算减少(2) 每个专家的批量至少减半(3) 路由实现简化、通信成本降低。这是反直觉的洞察——MoE 越简单越好。

k = 1 Anti-Shazeer 2017 3 Benefits

Shazeer 2017 的 MoE 路由——给定 token $x$,路由到 top-k 专家。router 变量 $W_r$ 产生 logits $h(x) = W_r \cdot x$,过 softmax 归一化——

$$p_i(x) = \frac{e^{h(x)_i}}{\sum_j^N e^{h(x)_j}} \quad (1)$$ $$y = \sum_{i \in T} p_i(x) E_i(x) \quad (2)$$

其中 $T$ 是选中的 top-k 索引集合。Shazeer 2017 主张——"路由到 $k > 1$ 专家是必要的——为了有意义的梯度"。直觉——学习路由如果没有比较至少两个专家的能力就不能工作。Ramachandran-Le 2018 进一步研究——发现底层用更大 k 重要

Switch 论文反驳这两个观察——

相反——我们用简化策略——
只路由到一个专家

我们证明这个简化保留模型质量
减少路由计算表现更好

这个 $k = 1$ 路由策略被称为 Switch 层

— SWITCH ROUTING

"Switch" 这个名字的含义——路由器只是"开关"——把 token 切换到一个专家

Switch 层的三个好处——

⚡ 三大好处

注意——MoE 和 Switch 路由用 $p_i(x)$ 作为乘子(Eq. 2 里的 gate value)——让路由器可微。这是关键——梯度仍然能流过路由器

这是 Switch 最美的地方——反直觉,但简单战胜复杂。MoE 越简单越好——这是 Switch 给出的哲学

广哥在硅谷◆ ◆ ◆
CHAPTER 03 · CAPACITY

专家容量 + Token Dropping

Expert capacity
📌 本节核心要点

TPU 需要静态声明大小——所以 Switch 的张量形状编译时确定,但路由动态专家容量解决这个矛盾——把 token 数平均分给专家,再乘容量因子(capacity factor)——比 1.0 大留缓冲。如果某专家收到的 token 超过容量——"dropped token"——计算被跳过,token 表示通过残差连接直接传到下一层。增大容量因子减少 token drop,但浪费计算和内存。Switch 论文发现——dropped token 率低于 1%不依赖专家数量

Switch 论文用 Mesh-TensorFlow——为分布式数据 + 模型并行设计的库。模型为 TPU 设计——需要静态声明大小

但 Switch 路由动态——不同 batch 不同 token 会路由到不同专家。这两件事如何统一?

解法——专家容量——

$$\text{expert capacity} = \left( \frac{\text{tokens per batch}}{\text{number of experts}} \right) \times \text{capacity factor} \quad (3)$$

专家容量是每个专家计算的 token 数——预先设定。容量因子大于 1.0 留缓冲,应付 token 分配不完美平衡。

如果一个专家收到的 token 超过容量——这些 token 被称为 "dropped tokens"——计算被跳过token 表示通过残差连接直接传到下一层

容量因子的权衡——

· 太小——很多 token 被 drop——模型质量下降
· 太大——浪费计算和内存(很多缓冲空槽)

Switch 论文发现——用负载平衡损失(下一章)配合足够高的系数dropped token 率通常 < 1%,并且不依赖专家数——这意味着负载平衡有效

广哥在硅谷◆ ◆ ◆
CHAPTER 04 · LOAD BALANCE

可微负载平衡损失

Differentiable load balance loss
📌 本节核心要点

Switch 简化了 Shazeer 2017 的"两个分离损失"(importance + load)——合成一个辅助损失。给定 $N$ 个专家、$T$ 个 token 的批量——损失是"实际分配比例 $f_i$"和"路由概率比例 $P_i$"的缩放点积$f_i$——分到专家 $i$ 的 token 比例(不可微)。$P_i$——批内分给专家 $i$ 的路由概率比例(可微)。两者都希望 = $1/N$(均匀路由)。损失在均匀分布下最小化$P$ 可微所以梯度能流回。系数 $\alpha = 10^{-2}$——"既保证平衡又不淹没主损失"

Switch 简化了 Shazeer 2017 的原始设计——把负载平衡损失importance weighting 损失合成一个辅助损失

给定 $N$ 个专家(索引 $i = 1$ 到 $N$)、批量 $B$ 含 $T$ 个 token——辅助损失为缩放点积——

$$\text{loss} = \alpha \cdot N \cdot \sum_{i=1}^N f_i \cdot P_i \quad (4)$$

其中——

· $f_i$实际分配到专家 $i$ 的 token 比例——

$$f_i = \frac{1}{T} \sum_{x \in B} \mathbb{1}\{\text{argmax } p(x) = i\} \quad (5)$$

· $P_i$批内分给专家 $i$ 的路由概率比例——

$$P_i = \frac{1}{T} \sum_{x \in B} p_i(x) \quad (6)$$

我们想要批内 token 均匀分到 $N$ 个专家——所以希望两个向量都 = $1/N$。Eq. 4 的辅助损失鼓励均匀路由——它在均匀分布下最小化。

注意——$P$ 可微,$f$ 不可微。但 $f \cdot P$ 通过 $P$ 可以求导——梯度能流回路由器。

损失乘以专家数 $N$ ——保持损失值随专家数变化而恒定。最后——超参 $\alpha$ 是辅助损失系数——Switch 论文用 $\alpha = 10^{-2}$——"足够大保证负载平衡,足够小不淹没主交叉熵目标"。

广哥在硅谷◆ ◆ ◆
CHAPTER 05 · BFLOAT16

选择性 bfloat16 精度

Selective precision
📌 本节核心要点

稀疏专家模型有训练不稳定性问题——硬切换决策让 softmax 计算对 bfloat16 精度敏感。GShard(Lepikhin 2020)不得不用 float32 训练——但通信成本巨大。Switch 论文展示——在模型的"局部区域选择性 cast 到 float32" 就能保持稳定性,不付出跨设备通信 float32 张量的代价。具体做法——把路由器输入 cast 到 float32路由计算在 float32 里做函数末尾把 dispatch/combine 张量 cast 回 bfloat16。结果——近乎 bfloat16 的速度 + float32 的训练稳定性

稀疏专家模型可能引入普通 Transformer 没有的训练困难——硬切换路由决策每层都做,可能造成不稳定。低精度格式 bfloat16 在路由 softmax 计算里会加剧问题

GShard(Lepikhin 2020)的解法——整个 MoE Transformer 都用 float32 训练。但 float32 比 bfloat16慢、占内存大、跨设备通信成本高

Switch 的解法——选择性精度——

我们展示——
通过在模型的局部部分选择性 cast 到 float32 精度
稳定性可以达成——
而不付出float32 张量跨设备通信的昂贵成本

— SELECTIVE PRECISION

具体做法——把路由器输入 cast 到 float32。路由函数接收 token,产生"dispatch 和 combine 张量"用于专家计算的选择和重组。float32 精度只在路由函数内部使用——在本地设备上的计算

因为函数末尾,dispatch 和 combine 张量被cast 回 bfloat16——不需要昂贵的 float32 张量跨设备通信

结果——"接近 bfloat16 的速度 + float32 的训练稳定性"

这是 Switch 的关键工程贡献——稀疏模型第一次可以稳定地用 bfloat16 训练。这解锁了万亿参数——因为float32 训练万亿参数太贵

广哥在硅谷◆ ◆ ◆
CHAPTER 06 · SCALING

扩展性质 · 7 倍预训练加速

7x pretraining speedup
📌 本节核心要点

Switch 在 C4(Colossal Clean Crawled Corpus)上做 mask LM 预训练。Switch vs T5(dense)头对头比较——同样 FLOPs/token、同样硬件、同样训练步数。三大发现——(1) Switch 在速度-质量上同时打败 dense 和 MoE Transformer(2) Switch 计算占用比 MoE 小(3) Switch 在低容量因子(1.0, 1.25)下表现更好——更适合大模型内存稀缺的场景Switch-Base 1.0 容量因子达到质量阈值用 62.8 小时,T5-Large 用 131.1 小时——2 倍快。在 mT5 baseline 上 91% 的 101 个语言得到4x+ 加速

7x Speedup 101 Languages 91% with 4x+ gain

Switch Transformer 的第一个测试——在 C4(Colossal Clean Crawled Corpus) 上做mask LM 预训练。用负对数 perplexity作为指标。

Switch vs MoE Transformer 头对头比较 (Table 1)——

模型容量因子100k 步质量 ↑到质量阈值 -1.50 时间 ↓速度
T5-Base-1.731未达到1600
T5-Large-1.550131.1 h470
MoE-Base1.0-1.57280.1 h860
Switch-Base1.0-1.56162.8 h1000
Switch-Base+1.0-1.53467.6 h780

三大发现——

(1) Switch 在速度-质量上同时打败 dense 和 MoE——固定计算和时间,Switch 取得最好结果。

(2) Switch 计算占用比 MoE 小——把 Switch 扩大到匹配 MoE 训练速度时(Switch-Base+),每步基础上也击败所有 MoE 和 Dense

(3) Switch 在低容量因子(1.0, 1.25)下更好——更适合大模型内存稀缺的场景

论文还测了多语言学习——在 101 个语言上和 mT5-Base 对比。所有 101 个语言都得到普遍提升91% 的语言在 mT5 baseline 上得到 4x+ 加速

广哥在硅谷◆ ◆ ◆
CHAPTER 07 · DISTILLATION

蒸馏 · 99% 压缩,30% 增益保留

Distillation
📌 本节核心要点

Switch 论文做了一件让我读完后惊讶的事——把稀疏预训练 + 专门化微调的模型成功蒸馏小 dense 模型。"模型大小减少最多 99%,同时保留大稀疏教师 30% 的质量增益"。这是"训练用 MoE,推理用 dense"的有趣范式——训练时用稀疏的便宜容量,推理时用 dense 的部署简便。这给边缘部署、移动设备开了一条路。

99% Compression 30% Gain Preserved Sparse → Dense

Switch 论文做了一件让我特别注意的事——蒸馏

大稀疏模型有部署难题——万亿参数不适合每台设备都装一份。蒸馏可以解决——

我们成功蒸馏稀疏预训练模型和专门化微调模型——
小 dense 模型

我们把模型大小最多减少 99%——
同时保留大稀疏教师 30% 的质量增益

— SWITCH TRANSFORMER DISTILLATION

这是"训练用稀疏,推理用 dense"新范式——

· 训练时——用 Switch 的便宜容量训出一个大稀疏模型
· 蒸馏时——把它的知识压缩到小 dense 模型
· 推理时——只部署 dense 模型——每台设备都能跑

"30% 质量增益保留"听起来不多——但对于 99% 压缩比来说是巨大的。例如——如果大稀疏模型比 dense baseline 提升 100%,那压缩后的小模型保留 30%——仍然显著好于原始 dense baseline

这给边缘部署、移动设备、低延迟服务开了一条路。"训练用云,推理用本地"——这个范式后来被很多 LLM 工作借鉴。

广哥在硅谷◆ ◆ ◆
CHAPTER 08 · TRILLION

万亿参数 · 4 倍快过 T5-XXL

Trillion-parameter Switch
📌 本节核心要点

论文最后展示万亿参数 Switch Transformer。结合数据并行 + 模型并行 + 专家并行三种并行——达到约 1 万亿参数。在 C4 上预训练——对比 T5-XXL 强基线(11 billion 参数 dense)——同时间到同质量阈值,Switch 快 4 倍。这是 MoE 谱系里的里程碑——稀疏激活模型第一次达到"参数千万亿级别"从 1991 的几个专家到 2021 的万亿参数——30 年——MoE 完成了从"小众思想"到"前沿 LLM 必备"的完整转变

~1 Trillion Params 4x faster than T5-XXL 3-way Parallelism

论文最后一节展示万亿参数 Switch Transformer——"把神经语言模型的规模又推进一步"

结合三种并行——

· 数据并行——多设备处理不同 batch
· 模型并行——一个层切到多设备
· 专家并行——不同专家放不同设备

组合让模型规模达到约 1 万亿参数

实验——vs T5-XXL(11 billion 参数 dense)——在 C4 上预训练。结果——同时间到同质量阈值,Switch 比 T5-XXL 快 4 倍

这是 MoE 谱系里的里程碑——稀疏激活模型第一次正式达到"参数万亿级"

1991 的几个专家2017 的 137B 参数2021 的万亿参数——30 年里 MoE 完成了从"小众思想"到"前沿 LLM 必备"的完整转变

Switch 之后——2021 GLaM(Google 1.2T MoE LLM)2023 GPT-4(被广泛认为是 MoE)2024 Mixtral 8x7B(第一个流行开源 MoE)——每一个都建立在 Switch 奠定的工程范式上

2021 年 1 月——
Switch Transformer 第一次达到万亿参数

2024 年——
Mixtral 8x7B 把同样的范式开源给所有人

2017 年的 Noam Shazeer——
2021 年的 Fedus + Zoph + Shazeer——
1991 年的 Jacobs + Jordan + Hinton——

34 年——同一个想法的进化

— THE MOE LINEAGE