实现数据集的奖励函数 ====================================== 最后更新:2025年6月2日。 对于每个数据集,我们需要实现一个奖励函数或利用一个奖励模型来计算生成响应的奖励。 我们在 `reward_score 目录 `_ 中预先实现了一些奖励函数。 您也可以使用自定义的奖励函数。 目前,我们支持 GSM8k 和 MATH 数据集的奖励函数。对于 RLHF 数据集(例如 `full_hh_rlhf`)和代码生成(例如 APPS),我们分别使用奖励模型和 SandBox(即将开源)进行评估。 RewardManager ------------- 在 PPO 后续训练脚本 `main_ppo.py `_ 的入口处,我们实现了一个 ``RewardManager``,它利用预先实现的奖励函数来计算每个响应的分数。 在 ``RewardManager`` 中,我们实现了一个 ``__call__`` 函数来计算每个响应的分数。 所有奖励函数都由 ``compute_score_fn`` 执行。 输入是一个 ``DataProto``,其中包含: - ``input_ids``, ``attention_mask``:应用 `chat_template` 后的 ``input_ids`` 和 ``attention_mask``,包括 prompt 和 response。 - ``responses``:响应 token。 - ``ground_truth``:当前 prompt 的真实答案字符串。存储在 ``DataProto`` 的 ``non_tensor_batch`` 中,应该在 parquet 文件中进行预处理。 - ``data_source``:当前 prompt 的数据集名称。存储在 ``DataProto`` 的 ``non_tensor_batch`` 中,应该在 parquet 文件中进行预处理。 在对响应进行反 token 化(detokenize)后,响应字符串和真实答案字符串将作为输入传递给 ``compute_score_fn`` 来计算每个响应的分数。 奖励函数 ---------------- 预先实现 ~~~~~~~~~~~~~~~ 我们在 `reward_score 目录 `_ 中预先实现了一些奖励函数。 - 在 `GSM8k 示例 `_ 中,我们强制响应在四个 `####` 之后输出最终答案,然后使用字符串匹配与真实答案进行比较。如果完全正确,得1分;如果格式正确,得0.1分;如果格式不正确,得0分。 - 在 `MATH 示例 `_ 中,我们遵循 `lm-evaluation-harness 仓库 `_ 中的实现。 自定义 ~~~~~~~~~~ 您可以在一个单独的文件中实现自定义奖励函数,并通过 ``custom_reward_function.path`` 和 ``custom_reward_function.name`` 来指定它们。关于它们的集合,请参阅 :ref:`config-explain-page`。 您的奖励函数的参数应该是 ``data_source``、``solution_str``、``ground_truth`` 和 ``extra_info``。 例如: .. code:: python def my_reward_fn(data_source, solution_str, ground_truth, extra_info=None): return len(solution_str)/100 如果您只测试单个自定义奖励函数,您可以简单地将其命名为 'compute_score',并将 ``custom_reward_function.name`` 留空。 要运行多个使用不同自定义奖励函数的测试,您可以为每次试验修改 ``custom_reward_function.path`` 和 ``custom_reward_function.name``。 例如,您可以创建一个名为 `my_reward.py` 的文件,并在其中实现多个奖励函数。这样,对于不同的试验,您只需要调整 ``custom_reward_function.name``,从而在脚本中进行多次测试会更加方便。