[Feature] Domino EP support and training optimizations for InternS1 Pro VL #1528

tina-wen wants to merge 2 commits into InternLM:main
Conversation
Force-pushed from `1bd80f0` to `fb3ae25`
@claude review
```python
def _prepare_llm_inputs(
    self,
    seq_ctx: SequenceContext,
    loss_ctx: dict[str, CELossContext] | None = None
) -> MoEModelOutputs:
    input_ids = seq_ctx.input_ids
    pixel_values = seq_ctx.pixel_values
```
Claude: Bug — Wrong return type annotation.
`_prepare_llm_inputs` returns `inputs_embeds` (a `torch.Tensor`), but the annotation says `MoEModelOutputs`. The annotation was carried over from the old `forward` signature during the extract-method refactor.
Suggested change:
```diff
 def _prepare_llm_inputs(
     self,
     seq_ctx: SequenceContext,
-    loss_ctx: dict[str, CELossContext] | None = None
-) -> MoEModelOutputs:
+) -> torch.Tensor:
     input_ids = seq_ctx.input_ids
     pixel_values = seq_ctx.pixel_values
```
```python
routing_weights_mean_global = local_gating_sum / valid_tokens

loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1)
return loss.sum() * balancing_loss_weight

def cal_tokens_per_expert(self) -> torch.Tensor:
```
Claude: Perf note — `maybe_offload_tensor` calls `wait_async_offload(async_offload_to_cpu(...))` synchronously, so the D2H copy blocks the calling stream. The `AsyncOffloadedTensor` abstraction exists but is never exploited for overlap.
If the intent is to overlap the D2H copy with the subsequent `accumulate_layer_balancing_loss` computation on the GPU, this function should return the `AsyncOffloadedTensor` directly and defer the wait to `maybe_wait_offload_tensor`. As written, the async machinery has no effect: the copy is fully synchronous.
If synchronous offload is intentional (simplicity over performance), consider simplifying to just `tensor.detach().cpu()` and removing the unused async infrastructure.
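A minimal sketch of the deferred-wait pattern this note describes, using a background thread as a stand-in for the CUDA async D2H copy (the function names echo the PR's `async_offload_to_cpu` / `maybe_wait_offload_tensor`, but this thread-based version is illustrative only, not the actual implementation):

```python
from concurrent.futures import Future, ThreadPoolExecutor

_offload_pool = ThreadPoolExecutor(max_workers=1)

def async_offload(data: list[float]) -> "Future[list[float]]":
    # Kick off the "D2H copy" without blocking the caller; the real code
    # would issue tensor.to("cpu", non_blocking=True) on a side stream.
    return _offload_pool.submit(list, data)

def wait_offload(fut: "Future[list[float]]") -> list[float]:
    # Deferred sync point: block only once the CPU copy is actually needed.
    return fut.result()

# Overlap: start the offload, run the balancing-loss work, then wait.
fut = async_offload([0.1, 0.2, 0.7])
gpu_side_work = sum(w * w for w in [0.5, 0.5])  # stand-in for the GPU compute
router_weights_cpu = wait_offload(fut)
```

The point of the pattern is that the wait moves to the consumer: the producer returns a handle immediately, and the compute between submit and wait is what gets overlapped with the copy.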
```python
)
accumulate_layer_balancing_loss(
    self.layer_balancing_loss,
    layer_idx=int(idx),
```
Claude: Warning — When `self.layer_balancing_loss` is not None, the z-loss computation is silently skipped. If a user configures both `layer_balancing_loss_cfg` and `z_loss_cfg`, the z-loss is ignored without any warning or error. Consider logging a warning during `build()`, or raising if both are configured, so users don't silently lose z-loss regularization.
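A hedged sketch of the build-time validation this warning suggests; `validate_loss_cfgs` and the cfg arguments are hypothetical names for illustration, not the PR's actual API:

```python
import logging

logger = logging.getLogger(__name__)

def validate_loss_cfgs(layer_balancing_loss_cfg, z_loss_cfg) -> bool:
    # Hypothetical build()-time check: the forward path skips z-loss whenever
    # layer-wise balancing is enabled, so warn instead of silently dropping it.
    if layer_balancing_loss_cfg is not None and z_loss_cfg is not None:
        logger.warning(
            "z_loss_cfg is ignored when layer_balancing_loss_cfg is set; "
            "z-loss regularization will not be applied."
        )
        return False
    return True
```

Returning a flag (or raising) makes the conflict visible in tests as well as in logs.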
```python
mask_list = torch.cat([ctx.mask for ctx in seq_ctx_list], dim=1)

# Initialize output containers
output: dict = {}
```
Claude: Nit — `mask_list` is a slightly misleading name: it is a tensor produced by concatenating the masks of all micro-batches, not a list. Consider `cat_mask` or `combined_mask` for consistency with the `cat_hidden_states` / `cat_position_ids` naming pattern used elsewhere in this method.
```python
self.layer_balancing_loss,
balancing_ctx=balancing_ctx,
num_experts_per_tok=self.config.num_experts_per_tok,
non_pad_token=router_weights.shape[1],
```
Claude: Bug — `non_pad_token` gets the wrong value when layer balancing is enabled.
When `self.layer_balancing_loss` is not None, the `_select_non_pad_router_logits` call above (which reassigns `router_weights` to shape `(num_layers, non_pad_seq, num_experts)`) is skipped. The `router_weights` variable here still holds the raw output of the last MoE layer, with shape `(seq_len, num_experts)`, so `router_weights.shape[1]` evaluates to `num_experts` (e.g. 256), not the number of non-padding tokens.
This produces incorrect loss values in single-GPU / non-distributed mode (the distributed path in `finalize()` ignores `non_pad_token`).
Fix: compute it from the mask directly:
```diff
-non_pad_token=router_weights.shape[1],
+non_pad_token=int(seq_ctx.mask.sum().item()),
```
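A torch-free illustration of the mismatch (shapes and values are made up for the example): with layer balancing enabled, `router_weights` keeps its raw `(seq_len, num_experts)` shape, so `shape[1]` counts experts rather than non-padding tokens.

```python
# Toy dimensions standing in for the real tensors.
seq_len, num_experts = 8, 4
mask = [1, 1, 1, 1, 1, 0, 0, 0]  # 5 real tokens, 3 padding

# What the buggy line computes: router_weights.shape[1] == num_experts.
buggy_non_pad_token = num_experts

# What the fix computes: seq_ctx.mask.sum() counts non-padding tokens.
fixed_non_pad_token = sum(mask)
```

With these toy values the buggy path reports 4 where the correct count is 5; with 256 experts the error would be far larger.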
```python
SequenceContext(
    input_ids=None,
    cu_seq_lens_q=single_seq_ctx.cu_seq_lens_q,
    cu_seq_lens_k=single_seq_ctx.cu_seq_lens_k,
    max_length_q=single_seq_ctx.max_length_q,
    max_length_k=single_seq_ctx.max_length_k,
    position_ids=single_seq_ctx.position_ids,
    num_padding=single_seq_ctx.num_padding,
    sequence_parallel_mesh=single_seq_ctx.sequence_parallel_mesh,
    inputs_embeds=inputs_embeds,
    rollout_routed_experts=single_seq_ctx.rollout_routed_experts,
    deepstack_visual_embeds=None,
    visual_pos_masks=None
)
```
Claude: Warning — Using explicit `SequenceContext(...)` construction instead of `seq_ctx.copy(...)` risks silently dropping fields.
The original code used `seq_ctx.copy(input_ids=None, inputs_embeds=..., ...)`, which preserves all fields not explicitly overridden (e.g. `device`, `block_table`, `image_grid_thw`, `num_img_tokens`, `raw_input_ids`, `shard_start`, `shard_size`). The new code constructs from scratch, so any field not listed falls back to its default — notably `device` defaults to `"cpu"`.
Consider using `copy()` instead:

```python
lang_seq_ctx_list.append(
    single_seq_ctx.copy(
        input_ids=None,
        inputs_embeds=inputs_embeds,
        deepstack_visual_embeds=None,
        visual_pos_masks=None,
    )
)
```

The same applies to the `elif isinstance(seq_ctx, SequenceContext)` branch below (line 225).
```python
)
return inputs_embeds

def forward(self, seq_ctx: SequenceContext | list[SequenceContext], loss_ctx: dict[str, CELossContext] | None = None) -> MoEModelOutputs:
```
Claude: Nit — This signature line is ~126 characters, exceeding the 119-character max line length configured in pyproject.toml.
Consider splitting the signature:

```python
def forward(
    self,
    seq_ctx: SequenceContext | list[SequenceContext],
    loss_ctx: dict[str, CELossContext] | None = None,
) -> MoEModelOutputs:
```
Claude:

## Summary

Adds Domino EP support and layer-wise MoE balance loss computation for InternS1 Pro VL training. The layer-wise approach splits the expert balance loss per layer to reduce the peak memory cost of accumulating all router outputs, with async CPU offloading for router tensors.

## Issues

### Critical

### Warning

### Nit

## Verdict

REQUEST_CHANGES
Description
This PR optimizes InternS1 Pro VL model training with three key changes:
Results: Performance ↑, Memory ↓, Accuracy unchanged