SARSA 是一种经典的 TD 算法。SARSA 是 state-action-reward-state-action 的缩写,原因是它使用 (st,at,rt,st+1,at+1) 这样的五元组来进行更新。
SARSA 的 TD 目标
SARSA 可以由下面的贝尔曼方程推导出来:
Qπ(st,at)=ESt+1,At+1[Rt+γ⋅Qπ(St+1,At+1)∣St=st,At=at]
对方程两边作近似:
- 使用函数估计器 q (表格或神经网络) 来估计 Qπ
- 将期望近似为抽样。具体来说,给定当前状态 st 和动作 at,我们可以通过执行动作 at 来获得奖励 rt 和下一个状态 st+1,然后基于 st+1 抽样得到新动作 a~t+1∼π(⋅∣st+1),得到 rt+γQπ(st+1,a~t+1)≈rt+γq(st+1,a~t+1)
于是我们得到 SARSA 的 TD 目标:
y^t=rt+γq(st+1,a~t+1)
更新规则
q(st,at) 和 y^t 都是对于 Qπ(st,at) 的估计,但是 y^t 是一个更好的估计,因为它使用了更多的信息。所以我们鼓励函数估计器去逼近 y^t。
对于学习率为 α 的表格型 SARSA:
q(st,at)=(1−α)q(st,at)+αy^t
对于学习率为 α 的使用神经网络 (参数记为 w ) 的 SARSA:
L(w)ΔwL(w)w=21[q(st,at;w)−y^t]2=(q(st,at;w)−y^t)∇wq(st,at;w)←w−α(q(st,at;w)−y^t)∇wq(st,at;w)
训练流程
将函数估计器设为 qnow,当前策略为 πnow
- 观察当前状态 st
- 根据当前策略 πnow 采样动作 at∼πnow(⋅∣st)
- 计算 q^t=qnow(st,at)
- 执行动作 at,观察奖励 rt 和下一个状态 st+1
- 根据当前策略 πnow 采样动作 a~t+1∼πnow(⋅∣st+1)
- 计算 q^t+1=qnow(st+1,a~t+1)
- 计算 TD 目标 y^t=rt+γq^t+1
- 按照更新规则更新函数估计器 qnow
- 更新策略,注意策略的生成方式与 SARSA 无关