余弦退火算法动态调整学习率


余弦退火算法动态调整学习率

余弦退火算法(Cosine Annealing with Warm Restarts) 是一种学习率调度策略,通过周期性调整学习率,在训练过程中动态平衡优化过程的探索(exploration)和收敛(convergence)。以下是其核心原理和实现细节:


1. 核心思想

  • 周期性学习率衰减:学习率按余弦函数从最大值(初始值)下降到最小值(eta_min),形成一个“周期”。
  • 重启机制(Warm Restarts):每个周期结束后,学习率重新恢复到最大值(或部分恢复),然后再次衰减。重启可以避免陷入局部最优,增强模型泛化能力。

2. 数学公式

学习率在第 ( t ) 个 step 的计算公式为:
$$
\eta_t = \eta_{\text{min}} + \frac{1}{2} \cdot (\eta_{\text{max}} - \eta_{\text{min}}) \cdot \left(1 + \cos\left(\frac{T_{\text{cur}}}{T_i} \cdot \pi\right)\right)
$$

  • $\eta_{\text{max}}$:初始学习率(如 config.lr)。
  • $ \eta_{\text{min}}$:学习率下限(如 1e-5)。
  • $T_{\text{cur}}$:当前周期内的步数(step)。
  • $T_i$:第 ( i ) 个周期的长度(随重启次数增加而倍增)。

3. 关键参数

  • **T_0**:初始周期长度(单位:epoch 或 step)。
  • **T_mult**:周期倍增因子(如 T_mult=2,则每个周期长度是前一个的 2 倍)。
  • **eta_min**:学习率下限(防止学习率过低导致训练停滞)。

4. 代码实现(PyTorch)

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# 定义优化器
optimizer = Adam(model.parameters(), lr=1e-3)

# 定义调度器
scheduler = CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=10,       # 初始周期长度(如 10 个 epoch)
    T_mult=2,     # 周期倍增因子
    eta_min=0     # 学习率下限
)

# 训练循环中更新学习率
for epoch in range(epochs):
    for batch in dataloader:
        train(...)
        scheduler.step()  # 每个 batch 更新学习率(也可每 epoch 更新)

5. 你的代码中的应用

在你的 Trainer 类中,调度器配置如下:

self.lr_scheduler = CosineAnnealingWarmRestarts(
    self.optimizer, 
    T_0=10,       # 初始周期为 10 个 epoch
    T_mult=2,     # 每个周期长度翻倍(10 → 20 → 40 → ...)
    eta_min=0     # 学习率下限为 0
)
  • 作用:在训练初期快速收敛,随后通过重启机制跳出局部最优,提升模型鲁棒性。
  • 配合训练策略:在每个周期结束时重启学习率,结合模型快照(snapshot)保存,可实现多模型集成(如 Snapshot Ensemble)。

7. 优势与适用场景

  • 优势
    • 避免陷入局部最优,增强模型泛化能力。
    • 通过周期性高学习率,重新探索参数空间。
  • 适用场景
    • 训练初期需要快速收敛,后期需要精细调整。
    • 数据集复杂、模型容易过拟合的任务(如街景字符识别)。

Author: qwq小小舒
Reprint policy: All articles in this blog are used except for special statements CC BY 4.0 reprint policy. If reproduced, please indicate source qwq小小舒 !
  TOC