Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/lerobot/policies/pi0/modeling_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,13 @@ def embed_image(self, image: torch.Tensor):
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features

def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.get_input_embeddings()(tokens)

def forward(
self,
Expand Down Expand Up @@ -665,8 +665,7 @@ def image_embed_func(img):
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
return lang_emb

lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
Expand Down
12 changes: 4 additions & 8 deletions src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ def embed_image(self, image: torch.Tensor):
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features

def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.get_input_embeddings()(tokens)

def forward(
self,
Expand Down Expand Up @@ -414,8 +414,7 @@ def image_embed_func(img):
# Process language instruction tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
return lang_emb
Comment on lines 414 to +417
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing the math.sqrt(...) scaling, this block no longer uses math; since the module is now unused in this file, the import math at the top should be removed to satisfy linters (ruff/flake8) and avoid dead imports.

Copilot uses AI. Check for mistakes.

lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
Expand All @@ -429,8 +428,7 @@ def lang_embed_func(tokens):

def fast_action_embed_func(fast_action_tokens):
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
return fast_emb

fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
Expand Down Expand Up @@ -663,7 +661,6 @@ def sample_actions_fast(
if t < max_decoding_steps - 1:
# embed the newly generated token
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)

Expand Down Expand Up @@ -768,7 +765,6 @@ def sample_actions_fast_kv_cache(
# Embed the single previous token
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)

Expand Down
Loading