Last updated on May 28, 2025 pm
1 2 3 4 5
| def entropy_from_logits(logits: torch.Tensor): """Calculate entropy from logits.""" pd = torch.nn.functional.softmax(logits, dim=-1) entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) return entropy
|
在查看verl代码时,发现其计算entropy的时候,使用了一种独特的方式,具体而言,这里的entropy_from_logits
是一种数值稳定写法,等价于熵的经典定义:
H(p)=−i∑pilogpi
具体而言,将第 i 个类的 logits
记作 ℓi,其概率即为:
pi=∑jeℓjeℓi
也就是 softmax
后的 logits
,那么
H(p)=−i∑pilogpi=−i∑pi(ℓi−logZ)
其中 Z=∑jeℓj,展开得到:
H(p)=−i∑piℓi+(logZ)i∑pi=−i∑piℓi+logZ
因为 ∑ipi=1,而
logZ=logsumexp({ℓi})
于是即可得到上面代码。
为什么数值稳定
1
| entropy = logsumexp(logits, dim=-1) - (softmax(logits)*logits).sum(dim=-1)
|
之所以称这是一种“数值稳定”的熵计算方式,主要有以下几点原因:
1. 避免直接算 logpi 的不稳定(核心原因)
最直观的香农熵写法是公式1,如果先做:
1 2 3
| p = softmax(logits) logp = torch.log(p) entropy = -(p * logp).sum(-1)
|
- 问题1: 当某些
logits
很负时,对应的 pi 会非常接近 0,这时 logpi 会变成一个很大的负数,p_i * logp_i
可能因为 “0×(−∞)” 导致数值上出现 NaN
- 问题2: 计算
log(p)
也会把下溢(underflow)的微小概率映射到 −∞,再做乘法和求和十分容易出错。
IEEE 754浮点标准里,log(0)
会被定义为-∞
,是一个合法的无穷值,而不是NaN;所以如果此时做0×(−∞)
,就成为一个不定式,结果就是NaN
2. 用 log-sum-exp 计算 logZ 的稳定性
在等式4中,我们只需要两个量:
- logZ,即
logsumexp(logits)
- ∑ipiℓi,即
(softmax(logits)*logits).sum()
其中 PyTorch 的 logsumexp
会自动做log∑ieℓi=m+log∑ieℓi−m,
先减掉 m=maxjℓj 再 exponentiate,避免了 eℓi 可能的上溢(overflow)
3. 避免大范围指数/对数运算
softmax(logits)
本身底层也会先减 max 再做 exp()
,保证数值不爆
logsumexp
也是同样的 trick
- 于是整个式子中不会出现 “先指数化得到极大或极小值,再取对数” 这种前后相抵但中间溢出的操作