LLMs Finetune系列(五) 讲解了 RAFT 全量参数对齐微调算法,RAFT 算法思路和Reject Sample 思路其实大差不差,基本可以认为只是换了一个叫法,本文讲解Llama2模型中全量参数对齐微调算法Reject Sample + PPO(原版论文中V5 实验效果最好,采用的方案),具体和前面讲解的RLHF 算法的区别如何,这里先上图:
上图微调算法和RLHF 中不一致的部分做了明显的标识,具体体现在两个方面:
1.奖励 Reward 函数设置
2.Finetune 过程中添加Reject Sample 算法,也即RAFT 算法。
LLMs 对齐过程中,让大语言模型的无害性,有帮助性;针对这个问题,Llama2 训练了分别训练了两个奖励模型模型,一个是对无害性的奖励,一个是对帮助性奖励;在误差函数中添加新增了一个边际(margin)标签,significantly better、better、slightly better、negligibly better、unsure,这里和instruct-GPT不一样的是,对于每个prompt 只生成了两个结果(为了保证多样性,使用 2 个 model 生成 response,并使用不同的 temperature),而不是像instruct-GPT 原理图上画的生成A、B、C、D 四个结果,进行排序。
具体在RLHF过程中,对两个奖励函数的使用如下:
PPO 的求解目标找到期望奖励最大时候的策略函数,也即LLMs
这里LOGIT 是sigmoid 函数的反向操作,WHITEN 操作是在样本批量维度,对奖励值进行标准化,查询到相关开源代码,其整体思路就是 (x-mean/std) 的变换:
def whiten(xs: torch.Tensor, shift_mean=True, distributed=True, group=None) -> torch.Tensor:
"""Whitens values"""
if distributed and dist.is_initialized():
mean, var, _ = get_global_statistics(xs, group=group)
else:
var, mean = torch.var_mean(xs)
whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
2.Reject Sample + PPO 算法交替进行
原文中讲述了这两种算法的不同:
显著区别具体来讲Reject Sample 偏重于广度,PPO 算法偏重于深度
广度——在拒绝采样中,模型对给定的提示探索 K 个样本,进行finetune,而 PPO 只进行一次生成。
深度——在 PPO 中,在训练的第 t 步,样本是更新后的模型策略的函数,该策略来自前一步的梯度更新后的 t-1。在拒绝采样微调中,在应用类似于SFT的微调之前,根据模型的初始策略采样所有输出以收集新数据集。然而,由于我们应用了迭代模型更新,所以两种 RL 算法之间的基本差异不太明显。
从前面给出的图中,Llama2 对prompt 进行抽样,对每个样本生成多个输出,然后对一个prompt,多个生成gi 对,进行Reject Sample 微调,然后对每个Prompt 中的某一个生成g,进行PPO微调。
其实整个Llama2 算法在Reward 模型上做了较大的改进,整个过程的核心还是PPO 算法,而PPO 算法核心是Critic 网络和action(policy 策略网络的参数更新),在LLM 的应用上多了Reward 和 SFT 两个原始训练好的网络
页面更新:2024-02-11
本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828
© CopyRight 2020-2024 All Rights Reserved. Powered By 71396.com 闽ICP备11008920号-4
闽公网安备35020302034903号