diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index aebf329645..f2fa6d6963 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -443,13 +443,13 @@ def embed_image(self, image: torch.Tensor):
         if image.dtype != torch.float32:
             image = image.to(torch.float32)
         image_outputs = self.paligemma.model.get_image_features(image)
-        features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
+        features = image_outputs.pooler_output
         if features.dtype != out_dtype:
             features = features.to(out_dtype)
         return features
 
     def embed_language_tokens(self, tokens: torch.Tensor):
-        return self.paligemma.model.language_model.embed_tokens(tokens)
+        return self.paligemma.model.language_model.get_input_embeddings()(tokens)
 
     def forward(
         self,
@@ -665,8 +665,7 @@ def image_embed_func(img):
         # Process language tokens
         def lang_embed_func(lang_tokens):
             lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
-            lang_emb_dim = lang_emb.shape[-1]
-            return lang_emb * math.sqrt(lang_emb_dim)
+            return lang_emb
 
         lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
         embs.append(lang_emb)
diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
index 1bcf9794c1..6864edb501 100644
--- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
+++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
@@ -258,13 +258,13 @@ def embed_image(self, image: torch.Tensor):
         if image.dtype != torch.float32:
             image = image.to(torch.float32)
         image_outputs = self.paligemma.model.get_image_features(image)
-        features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
+        features = image_outputs.pooler_output
         if features.dtype != out_dtype:
             features = features.to(out_dtype)
         return features
 
     def embed_language_tokens(self, tokens: torch.Tensor):
-        return self.paligemma.model.language_model.embed_tokens(tokens)
+        return self.paligemma.model.language_model.get_input_embeddings()(tokens)
 
     def forward(
         self,
@@ -414,8 +414,7 @@ def image_embed_func(img):
         # Process language instruction tokens
         def lang_embed_func(tokens):
             lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
-            lang_emb_dim = lang_emb.shape[-1]
-            return lang_emb * math.sqrt(lang_emb_dim)
+            return lang_emb
 
         lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
         embs.append(lang_emb)
@@ -429,8 +428,7 @@ def lang_embed_func(tokens):
 
         def fast_action_embed_func(fast_action_tokens):
             fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
-            fast_emb_dim = fast_emb.shape[-1]
-            return fast_emb * math.sqrt(fast_emb_dim)
+            return fast_emb
 
         fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
         embs.append(fast_action_emb)
@@ -663,7 +661,6 @@ def sample_actions_fast(
         if t < max_decoding_steps - 1:
             # embed the newly generated token
             next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
-            next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
             if prefix_embs.dtype == torch.bfloat16:
                 next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
@@ -768,7 +765,6 @@ def sample_actions_fast_kv_cache(
         # Embed the single previous token
         # We use embed_language_tokens directly to avoid overhead of full prefix embedding
         next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
-        next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
         if prefix_embs.dtype == torch.bfloat16:
             next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
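Note on the change: both edits track the current `transformers` API. `get_input_embeddings()` is the stable public accessor for the token-embedding module, so it keeps working across refactors that move or rename the private `embed_tokens` attribute. Dropping the explicit `* sqrt(hidden_size)` multiplies suggests the embedding scaling is now applied inside the Hugging Face PaliGemma/Gemma stack itself, so keeping the external factor would double-apply it. Below is a minimal sanity check under that assumption; the checkpoint id, token ids, and tolerance are illustrative, not part of the patch:

```python
import torch
from transformers import PaliGemmaForConditionalGeneration

# Illustrative checkpoint; any PaliGemma checkpoint should behave the same way.
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
tokens = torch.tensor([[2, 106, 108]])  # arbitrary token ids

# Stable accessor, mirroring the patched embed_language_tokens().
embed = model.model.language_model.get_input_embeddings()
emb = embed(tokens)

# Raw embedding weights, bypassing any scaling the module applies internally.
raw = embed.weight[tokens]
scale = model.config.text_config.hidden_size**0.5

# If this prints True, the sqrt(hidden_size) factor is already baked into the
# embedding module, and the removed external multiply would have double-scaled.
print(torch.allclose(emb.float(), (raw * scale).float(), atol=1e-4))
```

The same reasoning would cover the image path: if `get_image_features` already returns features at the scale the language model expects, the external `* hidden_size**0.5` on `pooler_output` is likewise redundant.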