PyTorch FSDP 后端 ====================== 最后更新时间:2025 年 12 月 2 日。 我们通过实现 actor、critic、reference、rollout 和 reward 模型的各种 worker 来支持 PyTorch FSDP 后端。我们还在 `fsdp_vllm.py `_ 中实现了 ``FSDPVLLMShardingManager``,用于在 FSDP 和 vLLM 之间重分片权重。 **优点** - 可轻松支持各种模型。 - 用户只需实现相应的 ``dtensor_weight_loader`` 来实现 FSDP 和 vLLM 之间的权重同步。而对于 ``hf_weight_loader``,用户可以直接应用 HF 和 vLLM 都支持的任何模型,无需任何代码更改。 - 易于组织每个模型的正向和反向计算。 **缺点** - 对于大型模型(例如 Llama 70B 和 405B)可扩展性较差。 - actor 和 rollout 之间的重分片开销可能大于 Megatron-LM 后端。 鉴于其简洁性,我们推荐使用 FSDP 后端进行算法研究和原型设计。 FSDP Workers -------------- ActorRolloutRefWorker ^^^^^^^^^^^^^^^^^^^^^ Actor/Rollout HybridEngine '''''''''''''''''''''''''' 1. HybridEngine、Actor 和 Rollout 初始化 API。 .. code:: python @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): ``ONE_TO_ALL``:当在驱动进程中调用 ``init_model`` 函数时,每个 worker(在 GPU 上)将执行以下模型初始化过程。 HybridEngine、Actor 和 Rollout 的初始化细节如下: 1. ``DataParallelPPOActor`` 实现当模型使用 FSDP 构建时,简单的 PPO 计算逻辑,包括计算 log prob、模型更新。 2. ``vLLMRollout`` 支持使用 vLLM 进行生成。我们修改了 vLLM Engine,使其在 SPMD 下执行,以适应我们的 ``WorkerGroup`` 设计。 3. ``FSDPVLLMShardingManager`` 是一个上下文管理器,用于在 actor 和 rollout 之间执行实际的重分片。 有关更多信息,请参阅 `源代码 `_。 1. 生成序列并重新计算 log prob .. code:: python @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): - ``Dispatch.DP_COMPUTE_PROTO``:数据将在 DP 维度上分派和收集。 - 在此函数中,rollout 模型将执行自回归生成,actor 模型将为生成的响应重新计算旧的 log prob。 3. 更新 actor 模型 .. code:: python @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): - 使用 PPO 和熵损失更新 actor 模型权重。 ReferenceModel '''''''''''''' 1. Reference model 初始化 reference 模型使用与 actor 模型相同的函数进行初始化,但不初始化 HybridEngine 和 Optimizer。然后,actor 模型也由 ``DataParallelPPOActor`` 包装。 2. 计算 reference log prob .. code:: python @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): - 在此函数中,reference 模型将调用 ``DataParallelPPOActor`` 中的计算 log prob 函数来计算 reference log prob。 CriticWorker and RewardWorker ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1. Model initialization 与 reference 模型非常相似。CriticWorker 将为 Optimizer 执行额外的初始化。 2. Compute Values for CriticWorker .. code:: python @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): 3. Update Critic .. code:: python @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): 4. Compute Reward .. code:: python @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_rm_score(self, data: DataProto): HybridShard ------------ 我们不支持 FSDP `HybridShard`。要支持此功能,我们可能需要构建一个二维设备网格,并为每个模型测试相应的 ``dtensor_weight_loader`` 和 ``hf_weight_loader``。