Add models with the FSDP backend
Last updated: 02/09/2025.
Models
In principle, our FSDP backend supports any HF model, and we can sync the actor model weights with vLLM using hf_weight_loader.py under third_party/vllm.
However, hf_weight_loader gathers the model's full state_dict during synchronization, which may cause OOM (out of memory). We recommend using dtensor_weight_loader, which gathers the full model parameters layer by layer to reduce peak memory usage. We already support dtensor weight loaders for the following models in dtensor_weight_loaders.py under third_party/vllm:
- GPT2LMHeadModel
- LlamaForCausalLM
- LLaMAForCausalLM
- MistralForCausalLM
- InternLMForCausalLM
- AquilaModel
- AquilaForCausalLM
- Phi3ForCausalLM
- GemmaForCausalLM
- Gemma2ForCausalLM
- GPTBigCodeForCausalLM
- Starcoder2ForCausalLM
- Qwen2ForCausalLM
- DeepseekV2ForCausalLM
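The memory argument above can be sketched in a few lines. This is an illustrative toy, not verl code: each sharded parameter is modeled as a tensor that must be "gathered" into a temporary full copy (standing in for an all-gather such as DTensor's full_tensor) before being handed to vLLM; the shard names and sizes are made up.

```python
import torch

# Hypothetical sharded model: parameter name -> local shard.
shards = {f"layers.{i}.weight": torch.zeros(1024) for i in range(8)}

def gather(shard: torch.Tensor) -> torch.Tensor:
    # Stand-in for an all-gather; here the "full" tensor is just a copy.
    return shard.clone()

def peak_elems_full_state_dict() -> int:
    # hf_weight_loader-style: every full tensor is resident at once.
    full = {name: gather(t) for name, t in shards.items()}
    return sum(t.numel() for t in full.values())

def peak_elems_layer_by_layer() -> int:
    # dtensor_weight_loader-style: only one full tensor alive at a time.
    peak = 0
    for name, t in shards.items():
        full_t = gather(t)   # temporary full copy
        peak = max(peak, full_t.numel())
        del full_t           # freed before gathering the next layer
    return peak

print(peak_elems_full_state_dict())  # 8192: all 8 layers resident at once
print(peak_elems_layer_by_layer())   # 1024: one layer at a time
```

With 8 layers of 1024 elements each, the full-state_dict path holds 8192 elements at its peak while the layer-by-layer path holds only 1024, which is why peak memory stays roughly one layer's worth regardless of model depth.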
To implement a dtensor_weight_loader for a model supported in vLLM, follow the guide for the Gemma model below:
1. Copy the load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) function from the vllm model class into dtensor_weight_loaders.py.
2. Change the function signature to (actor_weights: Dict, vllm_model: nn.Module).
3. Replace self with vllm_model.
4. Before each param = params_dict[name], add local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight), and change the subsequent weight loading to use local_loaded_weight.
5. Register the implemented dtensor weight loader in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
- params_dict = dict(self.named_parameters())
+ params_dict = dict(vllm_model.named_parameters())
loaded_params = set()
- for name, loaded_weight in weights:
+ for name, loaded_weight in actor_weights.items():
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
+ local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
+ weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
+ local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
- weight_loader(param, loaded_weight)
+ weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")
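Step 5 registers the new loader so it can be looked up by model architecture name. The sketch below shows how such a registry dispatch works; the registry name mirrors the one in dtensor_weight_loaders.py, but the loader body and the load_dtensor_weights dispatcher here are illustrative placeholders, not the actual verl implementation.

```python
from typing import Callable, Dict

def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model) -> None:
    ...  # the function implemented in the diff above

# Map architecture/class name -> loader function.
__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: Dict[str, Callable] = {
    "GemmaForCausalLM": gemma_dtensor_weight_loader,
}

def load_dtensor_weights(actor_weights: Dict, vllm_model) -> None:
    # Dispatch on the vLLM model's class name.
    arch = type(vllm_model).__name__
    loader = __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.get(arch)
    if loader is None:
        raise ValueError(f"No dtensor weight loader registered for {arch}")
    loader(actor_weights, vllm_model)
```

Dispatching on the class name keeps weight synchronization decoupled from any one model: supporting a new architecture only requires implementing its loader and adding one registry entry.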