From a3c43480f58e53b0fd6c582bafc905b5ab982173 Mon Sep 17 00:00:00 2001
From: Yichi Zhang <164144477@qq.com>
Date: Sat, 22 Nov 2025 20:16:52 +0800
Subject: [PATCH 1/3] Refactor model checkpoint saving logic

Only save once at the saving step when accumulation_step != 1.
---
 vla-scripts/finetune.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/vla-scripts/finetune.py b/vla-scripts/finetune.py
index 03263c1..24e01d3 100644
--- a/vla-scripts/finetune.py
+++ b/vla-scripts/finetune.py
@@ -1081,21 +1081,21 @@ def rename_state_dict_keys(state_dict, replace_map):
                 optimizer.zero_grad()
                 progress.update()
 
-            # Save model checkpoint: either keep latest checkpoint only or all checkpoints
-            if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
-                save_training_checkpoint(
-                    cfg=cfg,
-                    run_dir=run_dir,
-                    log_step=log_step,
-                    vla=vla,
-                    processor=processor,
-                    proprio_projector=proprio_projector if cfg.use_proprio else None,
-                    noisy_action_projector=None,
-                    action_head=action_head,
-                    train_dataset=train_dataset,
-                    distributed_state=distributed_state,
-                    new_state_dict=RAW_STATE_DICT,
-                )
+                # Save model checkpoint: either keep latest checkpoint only or all checkpoints
+                if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
+                    save_training_checkpoint(
+                        cfg=cfg,
+                        run_dir=run_dir,
+                        log_step=log_step,
+                        vla=vla,
+                        processor=processor,
+                        proprio_projector=proprio_projector if cfg.use_proprio else None,
+                        noisy_action_projector=None,
+                        action_head=action_head,
+                        train_dataset=train_dataset,
+                        distributed_state=distributed_state,
+                        new_state_dict=RAW_STATE_DICT,
+                    )
 
             # Test model on validation set
             if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0:

From b45cfda6a82bab9bc15b82ab9d305943bb64d12f Mon Sep 17 00:00:00 2001
From: Yichi Zhang <164144477@qq.com>
Date: Sat, 22 Nov 2025 20:21:01 +0800
Subject: [PATCH 2/3] Remove bos and eos tokens from hidden states at test time

Otherwise the action_hidden_states will contain a text token, and the
task_latent_states will contain a bos token.
---
 prismatic/extern/hf/modeling_prismatic.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/prismatic/extern/hf/modeling_prismatic.py b/prismatic/extern/hf/modeling_prismatic.py
index 945d03e..87b9c51 100644
--- a/prismatic/extern/hf/modeling_prismatic.py
+++ b/prismatic/extern/hf/modeling_prismatic.py
@@ -848,6 +848,7 @@ def _regression_or_discrete_prediction(
 
         multi_layer_hidden_states = []
         for item in language_model_output.hidden_states[0:]:
+            item = item[:, 1:-1, :]  # remove bos and eos tokens first
             # last_hidden_states = output.hidden_states[-1]  # (B, seq_len, D)
             # Get hidden states for text portion of prompt+response (after the vision patches)
             text_hidden_states = item

From 1bb4ef56dab1715b72e745aa64a81d86c56da098 Mon Sep 17 00:00:00 2001
From: Yichi Zhang <164144477@qq.com>
Date: Sat, 22 Nov 2025 20:30:53 +0800
Subject: [PATCH 3/3] Remove bos and eos tokens from hidden states during training
---
 vla-scripts/finetune.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vla-scripts/finetune.py b/vla-scripts/finetune.py
index 24e01d3..60c5517 100644
--- a/vla-scripts/finetune.py
+++ b/vla-scripts/finetune.py
@@ -396,9 +396,10 @@ def run_forward_pass(
 
     multi_layer_hidden_states = []
     for item in output.hidden_states[0:]:
+        item = item[:, 1:-1, :]  # remove bos token and eos token
         # last_hidden_states = output.hidden_states[-1]  # (B, seq_len, D)
         # Get hidden states for text portion of prompt+response (after the vision patches)
-        text_hidden_states = item[:, num_patches:-1]
+        text_hidden_states = item[:, num_patches:, :]
         # Get hidden states for action portion of response
         batch_size = batch["input_ids"].shape[0]
         # actions_hidden_states = text_hidden_states[:, -1, :].reshape(batch_size, 1, -1).to(torch.bfloat16)
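
Note on PATCH 1/3: with gradient accumulation, the save condition used to be
checked on every micro-batch, so the same checkpoint could be written
accumulation_step times at a single saving step; moving the block inside the
optimizer-step branch makes it fire once per gradient step. A minimal runnable
sketch of that logic, using simplified stand-in names (grad_accumulation_steps,
save_freq, num_batches) rather than the exact code in vla-scripts/finetune.py:

grad_accumulation_steps = 4
save_freq = 2  # save every 2 gradient steps
num_batches = 24

for batch_idx in range(num_batches):
    # ... forward/backward on one micro-batch ...
    if (batch_idx + 1) % grad_accumulation_steps == 0:
        gradient_step_idx = (batch_idx + 1) // grad_accumulation_steps
        # optimizer.step(); optimizer.zero_grad()

        # The patch moves the save check in here, so it runs once per
        # optimizer step instead of once per micro-batch.
        if gradient_step_idx > 0 and gradient_step_idx % save_freq == 0:
            print(f"save checkpoint at gradient step {gradient_step_idx}")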
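
Note on PATCH 2/3 and 3/3: a sketch of the indexing behind the new slices,
assuming the sequence layout [bos, vision patches, text tokens, eos]; the
sizes below are arbitrary and item is a stand-in for one entry of
output.hidden_states:

import torch

batch_size, num_patches, num_text_tokens, hidden_dim = 2, 4, 5, 8
seq_len = 1 + num_patches + num_text_tokens + 1  # bos + patches + text + eos

# Stand-in for one entry of output.hidden_states: (B, seq_len, D).
item = torch.randn(batch_size, seq_len, hidden_dim)

# The patches strip bos (index 0) and eos (index -1) first ...
item = item[:, 1:-1, :]  # now (B, num_patches + num_text_tokens, D)

# ... so the text tokens start directly after the vision patches and run to
# the end, which is why [:, num_patches:-1] becomes [:, num_patches:, :].
text_hidden_states = item[:, num_patches:, :]
assert text_hidden_states.shape == (batch_size, num_text_tokens, hidden_dim)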