一、什么是Batch Normalization?
Batch Normalization(简称BN)是在每个训练批次(batch)内,对网络中间层的激活值进行规范化(归一化),使它们具有均值为0、方差为1的分布。其核心思想是缓解“内部协变量偏移”(Internal Covariate Shift),即网络层输入分布的变化。
二、为什么需要Batch Normalization?
- 加快训练速度:减少对初始化和学习率的敏感性,加快模型收敛。
- 改善梯度传播:缓解梯度消失和梯度爆炸问题。
- 提升模型性能:允许使用更高的学习率,甚至在深层网络中实现更优表现。
- 一定的正则化效果:因加入的小批量噪声,具有一定的正则化作用。
三、Batch Normalization的工作原理和流程
1. 归一化
在每一层中,对当前批次的数据进行以下操作:
- 计算批次的均值:
- 计算批次的标准差:
然后,将激活值归一化:
其中,是一个小值,用于防止除零。
2. 缩放和平移
为了恢复模型的表达能力,BN会引入两个可学习的参数:
- 缩放参数:
- 移动参数:
最终输出:
这样,网络可以自主学习到合适的归一化尺度和偏移。
四、数学公式具体描述
给定一个批次 ,每个
是一组激活值(比如一个样本的某一层所有神经元的输出)。
- 批次平均值:
- 批次方差:
- 归一化:
- 最终输出:
其中, 和
是可以学习的参数。
五、训练和推理的区别
- 训练阶段:使用每个批次的均值和方差进行标准化。
- 推理阶段:使用训练过程中累计的全局平均值和方差(滑动平均)进行标准化,以确保测试时的一致性。
六、优缺点
优点
- 提升训练速度,缩短训练时间。
- 改善深层网络的训练稳定性。
- 允许使用更高的学习率。
- 在一定程度上具有正则化作用。
缺点
- 增加了计算和内存开销。
- 在批次大小很小时效果较差(因为估计的均值和方差不够稳定)。
- 在某些模型中可能会引入额外的复杂性。
七、应用场景
- 主要应用于卷积神经网络(CNN)、全连接网络(MLP)、Transformer等。
- 适合在训练时使用,推理时用估算的统计量替代。