余弦退火算法动态调整学习率
余弦退火算法(Cosine Annealing with Warm Restarts) 是一种学习率调度策略,通过周期性调整学习率,在训练过程中动态平衡优化过程的探索(exploration)和收敛(convergence)。以下是其核心原理和实现细节:
1. 核心思想
- 周期性学习率衰减:学习率按余弦函数从最大值(初始值)下降到最小值(
eta_min
),形成一个“周期”。 - 重启机制(Warm Restarts):每个周期结束后,学习率重新恢复到最大值(或部分恢复),然后再次衰减。重启可以避免陷入局部最优,增强模型泛化能力。
2. 数学公式
学习率在第 ( t ) 个 step 的计算公式为:
$$
\eta_t = \eta_{\text{min}} + \frac{1}{2} \cdot (\eta_{\text{max}} - \eta_{\text{min}}) \cdot \left(1 + \cos\left(\frac{T_{\text{cur}}}{T_i} \cdot \pi\right)\right)
$$
- $\eta_{\text{max}}$:初始学习率(如
config.lr
)。 - $ \eta_{\text{min}}$:学习率下限(如
1e-5
)。 - $T_{\text{cur}}$:当前周期内的步数(step)。
- $T_i$:第 ( i ) 个周期的长度(随重启次数增加而倍增)。
3. 关键参数
- **
T_0
**:初始周期长度(单位:epoch 或 step)。 - **
T_mult
**:周期倍增因子(如T_mult=2
,则每个周期长度是前一个的 2 倍)。 - **
eta_min
**:学习率下限(防止学习率过低导致训练停滞)。
4. 代码实现(PyTorch)
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# 定义优化器
optimizer = Adam(model.parameters(), lr=1e-3)
# 定义调度器
scheduler = CosineAnnealingWarmRestarts(
optimizer,
T_0=10, # 初始周期长度(如 10 个 epoch)
T_mult=2, # 周期倍增因子
eta_min=0 # 学习率下限
)
# 训练循环中更新学习率
for epoch in range(epochs):
for batch in dataloader:
train(...)
scheduler.step() # 每个 batch 更新学习率(也可每 epoch 更新)
5. 你的代码中的应用
在你的 Trainer
类中,调度器配置如下:
self.lr_scheduler = CosineAnnealingWarmRestarts(
self.optimizer,
T_0=10, # 初始周期为 10 个 epoch
T_mult=2, # 每个周期长度翻倍(10 → 20 → 40 → ...)
eta_min=0 # 学习率下限为 0
)
- 作用:在训练初期快速收敛,随后通过重启机制跳出局部最优,提升模型鲁棒性。
- 配合训练策略:在每个周期结束时重启学习率,结合模型快照(snapshot)保存,可实现多模型集成(如 Snapshot Ensemble)。
7. 优势与适用场景
- 优势:
- 避免陷入局部最优,增强模型泛化能力。
- 通过周期性高学习率,重新探索参数空间。
- 适用场景:
- 训练初期需要快速收敛,后期需要精细调整。
- 数据集复杂、模型容易过拟合的任务(如街景字符识别)。