import torch


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty: str) -> torch.FloatTensor:
    """Compute the per-token KL divergence given logprob and ref_logprob.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
    See http://joschu.net/blog/kl-approx.html for more discussion of the estimators.
    Args:
        logprob: Per-token log-probabilities of the sampled tokens under the current policy.
        ref_logprob: Per-token log-probabilities of the same tokens under the reference policy.
        kl_penalty: Which estimator to use: "kl"/"k1", "abs", "mse"/"k2", "low_var_kl"/"k3", or "full".
    Returns:
        A tensor of per-token KL penalty values with the same shape as logprob.
""" if kl_penalty in ("kl", "k1"): return logprob - ref_logprob
if kl_penalty == "abs": return (logprob - ref_logprob).abs()
if kl_penalty in ("mse", "k2"): return0.5 * (logprob - ref_logprob).square()
    # J. Schulman. Approximating KL divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html.
    if kl_penalty in ("low_var_kl", "k3"):
        # k3 estimator: ratio - log(ratio) - 1 with ratio = q/p; unbiased,
        # low-variance, and always non-negative. Clamped for numerical stability.
        kl = ref_logprob - logprob
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full": # so, here logprob and ref_logprob should contain the logits for every token in vocabulary raise NotImplementedError