预处理训练后数据 ======================================== 最后更新时间:2025年02月09日。 在开始训练后任务之前,我们需要为策略训练准备数据。数据应以 parquet 格式存储。 我们提供了几个针对不同数据集(包括 GSM8K、MATH、HelloSwag、Full_hh_rlhf)的数据预处理脚本。要准备其他数据集,我们需要遵循以下步骤:数据预处理脚本可以分为两部分: 1. 第一部分是通用部分,它从 huggingface 的 ``datasets`` 包加载数据集。然后使用 ``make_map_fn`` 预处理数据集,并将其存储为 parquet 格式。 .. code:: python import re import os import datasets from verl.utils.hdfs_io import copy, makedirs import argparse # To extract the solution for each prompts in the dataset # def extract_solution(solution_str): # ... if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--local_dir', default='/opt/tiger/gsm8k') parser.add_argument('--hdfs_dir', default=None) args = parser.parse_args() num_few_shot = 5 data_source = 'openai/gsm8k' dataset = datasets.load_dataset(data_source, 'main') train_dataset = dataset['train'] test_dataset = dataset['test'] # Construct a `def make_map_fn(split)` for the corresponding datasets. # ... train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True) local_dir = args.local_dir hdfs_dir = args.hdfs_dir train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) makedirs(hdfs_dir) copy(src=local_dir, dst=hdfs_dir) 2. 用户需要自行实现 ``make_map_fn()`` 函数(以及 ``extract_solution``)来支持不同的数据集或任务。 我们已经实现了 GSM8k、MATH、Hellaswag 和 Full_hh_rlhf 数据集的数据预处理。下面以 GSM8k 数据集为例: **GSM8K** 在 ``make_map_fn`` 中,每个数据字段应包含以下 5 个字段: 1. ``data_source``:数据集的名称。用于在 ``RewardModel`` 中索引相应的奖励函数。 2. ``prompt``:此字段应采用 huggingface chat_template 的格式。``RLHFDataset`` 中的分词器将应用 chat 模板并对 prompt 进行分词。 3. ``ability``:定义任务类别。 4. ``reward_model``:目前,我们仅在评估期间使用 ``ground_truth`` 字段。``ground_truth`` 由 ``extract_solution`` 函数计算。**请注意**,相应的奖励函数的实现应与此提取的 ``ground_truth`` 相符。 5. ``extra_info``:记录当前 prompt 的一些信息。目前不使用。 .. code:: python def extract_solution(solution_str): solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) # extract the solution after #### assert solution is not None final_solution = solution.group(0) final_solution = final_solution.split('#### ')[1].replace(',', '') return final_solution instruction_following = "Let's think step by step and output the final answer after \"####\"." # add a row to each data item that represents a unique id def make_map_fn(split): def process_fn(example, idx): question = example.pop('question') question = question + ' ' + instruction_following answer = example.pop('answer') solution = extract_solution(answer) data = { "data_source": data_source, "prompt": [{ "role": "user", "content": question }], "ability": "math", "reward_model": { "style": "rule", "ground_truth": solution }, "extra_info": { 'split': split, 'index': idx } } return data return process_fn