Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/lerobot/policies/pi0/modeling_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,13 @@ def embed_image(self, image: torch.Tensor):
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features

def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.get_inputs_embeddings()(tokens)

def forward(
self,
Expand Down Expand Up @@ -665,8 +665,7 @@ def image_embed_func(img):
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
return lang_emb

lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
Expand Down
12 changes: 4 additions & 8 deletions src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ def embed_image(self, image: torch.Tensor):
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features

def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.get_inputs_embeddings()(tokens)

def forward(
self,
Expand Down Expand Up @@ -414,8 +414,7 @@ def image_embed_func(img):
# Process language instruction tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
return lang_emb
Comment on lines 414 to +417
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing the math.sqrt(...) scaling, this block no longer uses math; since the module is now unused in this file, the import math at the top should be removed to satisfy linters (ruff/flake8) and avoid dead imports.

Copilot uses AI. Check for mistakes.

lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
Expand All @@ -429,8 +428,7 @@ def lang_embed_func(tokens):

def fast_action_embed_func(fast_action_tokens):
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
return fast_emb

fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
Expand Down Expand Up @@ -663,7 +661,6 @@ def sample_actions_fast(
if t < max_decoding_steps - 1:
# embed the newly generated token
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)

Expand Down Expand Up @@ -768,7 +765,6 @@ def sample_actions_fast_kv_cache(
# Embed the single previous token
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)

Expand Down
Loading