
[Feature] Domino EP support and training optimizations for InternS1 Pro VL #1528

Open
tina-wen wants to merge 2 commits into InternLM:main from tina-wen:split_bal_loss

Conversation

Contributor

@tina-wen tina-wen commented Mar 3, 2026

Description

This PR optimizes InternS1 Pro VL model training with two key changes:

  • Domino EP: Add support for domino_ep parallelism
  • Layer-wise MoE loss: Split expert balance loss computation to reduce memory

Results: Performance ↑, Memory ↓, Accuracy unchanged
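For readers unfamiliar with the technique, here is a minimal sketch of the layer-wise idea, assuming a Switch-Transformer-style balance loss (shapes, names, and the loss formula are illustrative, not this repo's actual API). Each layer contributes a scalar loss immediately, so its router outputs can be freed instead of being stacked across all layers until the end of the forward pass:

```python
import numpy as np

def layer_balancing_loss(routing_weights: np.ndarray, top_k: int) -> float:
    """Switch-style load-balancing loss for a single MoE layer.

    routing_weights: (num_tokens, num_experts) softmax router probabilities.
    """
    num_tokens, num_experts = routing_weights.shape
    # Fraction of top-k routing slots assigned to each expert.
    topk_idx = np.argsort(-routing_weights, axis=-1)[:, :top_k]
    counts = np.bincount(topk_idx.ravel(), minlength=num_experts)
    tokens_per_expert = counts / (num_tokens * top_k)
    # Mean router probability per expert.
    router_prob_per_expert = routing_weights.mean(axis=0)
    return num_experts * float((tokens_per_expert * router_prob_per_expert).sum())

# Layer-wise accumulation: each layer adds a scalar, and its router weights
# go out of scope right away, keeping peak memory at O(seq * experts)
# instead of O(layers * seq * experts).
rng = np.random.default_rng(0)
num_layers, seq_len, num_experts = 4, 16, 8
total = 0.0
for _ in range(num_layers):
    logits = rng.normal(size=(seq_len, num_experts))
    weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    total += layer_balancing_loss(weights, top_k=2)
loss = total / num_layers
```

With perfectly uniform routing probabilities the loss reaches its minimum of 1.0, which is the usual sanity check for this family of losses.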

@tina-wen tina-wen force-pushed the split_bal_loss branch 2 times, most recently from 1bd80f0 to fb3ae25 Compare April 1, 2026 14:17
@HAOCHENYE
Collaborator

@claude review

Comment on lines +141 to 146
```python
def _prepare_llm_inputs(
    self,
    seq_ctx: SequenceContext,
    loss_ctx: dict[str, CELossContext] | None = None
) -> MoEModelOutputs:
    input_ids = seq_ctx.input_ids
    pixel_values = seq_ctx.pixel_values
```

Claude: Bug — Wrong return type annotation.

_prepare_llm_inputs returns inputs_embeds (a torch.Tensor), but the annotation says MoEModelOutputs. This was carried over from the old forward signature during the extract-method refactor.

Suggested change

```python
def _prepare_llm_inputs(
    self,
    seq_ctx: SequenceContext,
    loss_ctx: dict[str, CELossContext] | None = None
) -> MoEModelOutputs:
    input_ids = seq_ctx.input_ids
    pixel_values = seq_ctx.pixel_values
```

```python
def _prepare_llm_inputs(
    self,
    seq_ctx: SequenceContext,
) -> torch.Tensor:
```

Comment on lines +220 to +225
```python
        routing_weights_mean_global = local_gating_sum / valid_tokens

        loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1)
        return loss.sum() * balancing_loss_weight

    def cal_tokens_per_expert(self) -> torch.Tensor:
```

Claude: Perf note — maybe_offload_tensor calls wait_async_offload(async_offload_to_cpu(...)) synchronously, so the D2H copy blocks the calling stream. The AsyncOffloadedTensor abstraction exists but is never exploited for overlap.

If the intent is to overlap the D2H copy with the subsequent accumulate_layer_balancing_loss computation on GPU, this function should return the AsyncOffloadedTensor directly and defer the wait to maybe_wait_offload_tensor. Currently the async machinery has no effect — the copy is fully synchronous.

If synchronous offload is intentional (simplicity over performance), consider simplifying to just tensor.detach().cpu() and removing the unused async infrastructure.
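As a toy illustration of the deferred-wait pattern this comment proposes (a thread pool stands in for the CUDA copy stream; every name here is hypothetical, not the repo's real API): kick off the offload, keep computing, and synchronize only where the offloaded value is consumed.

```python
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)

class AsyncOffloaded:
    """Handle for a copy that may still be in flight."""
    def __init__(self, future):
        self._future = future

    def wait(self):
        # Block only at the point of use.
        return self._future.result()

def async_offload(data):
    # Stand-in for an async D2H copy: start it and return immediately.
    return AsyncOffloaded(_executor.submit(list, data))

def maybe_wait_offload(handle):
    return handle.wait()

handle = async_offload(range(4))         # copy starts in the background
partial = sum(i * i for i in range(4))   # compute overlapped with the copy
offloaded = maybe_wait_offload(handle)   # synchronize only when needed
```

In real PyTorch code the same structure would use a pinned-memory destination, `tensor.to("cpu", non_blocking=True)` on a side stream, and a recorded CUDA event that `wait()` synchronizes on.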

```python
)
accumulate_layer_balancing_loss(
    self.layer_balancing_loss,
    layer_idx=int(idx),
```

Claude: Warning — When self.layer_balancing_loss is not None, z-loss computation is silently skipped. If a user configures both layer_balancing_loss_cfg and z_loss_cfg, the z-loss will be ignored without any warning or error. Consider logging a warning during build() or raising if both are configured, so users don't silently lose z-loss regularization.
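A build-time check along these lines would surface the conflict (the cfg names mirror those in the comment, but the function itself is a hypothetical sketch, not existing code):

```python
# Hypothetical validation hook: fail fast instead of silently dropping
# z-loss when layer balancing is enabled. Both argument names are taken
# from the review comment; the function is illustrative.
def validate_loss_cfg(layer_balancing_loss_cfg, z_loss_cfg):
    if layer_balancing_loss_cfg is not None and z_loss_cfg is not None:
        raise ValueError(
            "z_loss_cfg is ignored when layer_balancing_loss_cfg is set; "
            "configure only one of them (or implement both paths)."
        )
```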

Comment on lines +460 to 464
```python
mask_list = torch.cat([ctx.mask for ctx in seq_ctx_list], dim=1)

# Initialize output containers
output: dict = {}
```


Claude: Nit — mask_list is computed by concatenating masks from all micro-batches, but the name is slightly misleading since the result is a tensor, not a list. Consider cat_mask or combined_mask for consistency with the cat_hidden_states / cat_position_ids naming pattern used elsewhere in this method.

```python
    self.layer_balancing_loss,
    balancing_ctx=balancing_ctx,
    num_experts_per_tok=self.config.num_experts_per_tok,
    non_pad_token=router_weights.shape[1],
```

Claude: Bug — non_pad_token gets the wrong value when layer balancing is enabled.

When self.layer_balancing_loss is not None, the _select_non_pad_router_logits call above (which reassigns router_weights to shape (num_layers, non_pad_seq, num_experts)) is skipped. The router_weights variable here still holds the raw output from the last MoE layer with shape (seq_len, num_experts). So router_weights.shape[1] evaluates to num_experts (e.g. 256), not the number of non-padding tokens.

This produces incorrect loss values in single-GPU / non-distributed mode (the distributed path in finalize() ignores non_pad_token).

Fix: compute from the mask directly:

Suggested change

```python
non_pad_token=router_weights.shape[1],
```

```python
non_pad_token=int(seq_ctx.mask.sum().item()),
```
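A tiny numpy illustration of the mix-up (shapes are illustrative): when router_weights keeps the raw per-layer shape (seq_len, num_experts), shape[1] is the expert count, while the mask sum gives the intended non-padding token count.

```python
import numpy as np

seq_len, num_experts = 10, 256
router_weights = np.zeros((seq_len, num_experts))  # raw last-layer output
mask = np.array([1] * 7 + [0] * 3)                 # 7 real tokens, 3 padding

wrong = router_weights.shape[1]  # 256: the number of experts
right = int(mask.sum())          # 7: the number of non-padding tokens
```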

Comment on lines +207 to +220
```python
SequenceContext(
    input_ids=None,
    cu_seq_lens_q=single_seq_ctx.cu_seq_lens_q,
    cu_seq_lens_k=single_seq_ctx.cu_seq_lens_k,
    max_length_q=single_seq_ctx.max_length_q,
    max_length_k=single_seq_ctx.max_length_k,
    position_ids=single_seq_ctx.position_ids,
    num_padding=single_seq_ctx.num_padding,
    sequence_parallel_mesh=single_seq_ctx.sequence_parallel_mesh,
    inputs_embeds=inputs_embeds,
    rollout_routed_experts=single_seq_ctx.rollout_routed_experts,
    deepstack_visual_embeds=None,
    visual_pos_masks=None
)
```

Claude: Warning — Using explicit SequenceContext(...) construction instead of seq_ctx.copy(...) risks silently dropping fields.

The original code used seq_ctx.copy(input_ids=None, inputs_embeds=..., ...) which preserves all fields not explicitly overridden (e.g. device, block_table, image_grid_thw, num_img_tokens, raw_input_ids, shard_start, shard_size). The new code constructs from scratch, so any field not listed falls back to its default — notably device defaults to "cpu".

Consider using copy() instead:

```python
lang_seq_ctx_list.append(
    single_seq_ctx.copy(
        input_ids=None,
        inputs_embeds=inputs_embeds,
        deepstack_visual_embeds=None,
        visual_pos_masks=None,
    )
)
```

Same applies to the elif isinstance(seq_ctx, SequenceContext) branch below (line 225).
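The failure mode generalizes beyond this codebase; a self-contained sketch with dataclasses (field names are illustrative, not SequenceContext's real fields) shows re-construction silently reverting unlisted fields to their defaults while a copy/replace preserves them:

```python
from dataclasses import dataclass, replace

@dataclass
class Ctx:
    input_ids: object = None
    device: str = "cpu"        # default that re-construction falls back to
    block_table: object = None

ctx = Ctx(input_ids=[1, 2], device="cuda", block_table={"a": 1})

rebuilt = Ctx(input_ids=None)           # device silently reverts to "cpu"
copied = replace(ctx, input_ids=None)   # device stays "cuda"; other fields kept
```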

```python
)
return inputs_embeds

def forward(self, seq_ctx: SequenceContext | list[SequenceContext], loss_ctx: dict[str, CELossContext] | None = None) -> MoEModelOutputs:
```

Claude: Nit — This line is ~126 characters, exceeding the 119-character max line length configured in pyproject.toml.

Consider splitting the signature:

```python
def forward(
    self,
    seq_ctx: SequenceContext | list[SequenceContext],
    loss_ctx: dict[str, CELossContext] | None = None,
) -> MoEModelOutputs:
```

Contributor

claude bot commented Apr 13, 2026

Claude: Summary

Adds Domino EP support and layer-wise MoE balance loss computation for InternS1 Pro VL training. The layer-wise approach splits expert balance loss per-layer to reduce peak memory from accumulating all router outputs, with async CPU offloading for router tensors.

Issues

Critical

  • moe.py:825 — non_pad_token=router_weights.shape[1] uses the wrong tensor when layer balancing is enabled. router_weights holds the last MoE layer's raw output (seq, num_experts), so shape[1] = num_experts, not the non-padding token count. Produces incorrect loss in non-distributed mode. Fix: non_pad_token=int(seq_ctx.mask.sum().item()).
  • modeling_qwen3_vl.py:141-146 — _prepare_llm_inputs return type annotated as MoEModelOutputs but actually returns torch.Tensor.

Warning

  • modeling_qwen3_vl.py:207-236 — Replaced seq_ctx.copy(...) with explicit SequenceContext(...) construction, silently dropping fields like device (defaults to "cpu"), block_table, image_grid_thw, etc. Use copy() to preserve all original fields.
  • moe.py:753 — Z-loss is silently disabled when layer_balancing_loss is enabled. If both are configured, users get no warning that z-loss is being skipped.
  • layer_moe_loss.py:220-225 — maybe_offload_tensor synchronously waits on the async D2H copy, so the async machinery provides no overlap benefit. Either defer the wait to maybe_wait_offload_tensor, or simplify to plain .cpu().

Nit

  • modeling_qwen3_vl.py:200 — forward signature exceeds the 119-char line limit.
  • moe.py:460-464 — mask_list is a tensor, not a list; naming is inconsistent with the cat_hidden_states / cat_position_ids convention.

Verdict

REQUEST_CHANGES



3 participants