-
Notifications
You must be signed in to change notification settings - Fork 4.2k
fix(pi0-fast): don't apply embed scaling #3304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zucchini-nlp
wants to merge
2
commits into
huggingface:main
Choose a base branch
from
zucchini-nlp:pio-fast-embed-scaling
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+7
−12
Open
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -258,13 +258,13 @@ def embed_image(self, image: torch.Tensor): | |
| if image.dtype != torch.float32: | ||
| image = image.to(torch.float32) | ||
| image_outputs = self.paligemma.model.get_image_features(image) | ||
| features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 | ||
| features = image_outputs.pooler_output | ||
| if features.dtype != out_dtype: | ||
| features = features.to(out_dtype) | ||
| return features | ||
|
|
||
| def embed_language_tokens(self, tokens: torch.Tensor): | ||
| return self.paligemma.model.language_model.embed_tokens(tokens) | ||
| return self.paligemma.model.language_model.get_inputs_embeddings()(tokens) | ||
zucchini-nlp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def forward( | ||
| self, | ||
|
|
@@ -414,8 +414,7 @@ def image_embed_func(img): | |
| # Process language instruction tokens | ||
| def lang_embed_func(tokens): | ||
| lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) | ||
| lang_emb_dim = lang_emb.shape[-1] | ||
| return lang_emb * math.sqrt(lang_emb_dim) | ||
| return lang_emb | ||
|
Comment on lines
414
to
+417
|
||
|
|
||
| lang_emb = self._apply_checkpoint(lang_embed_func, tokens) | ||
| embs.append(lang_emb) | ||
|
|
@@ -429,8 +428,7 @@ def lang_embed_func(tokens): | |
|
|
||
| def fast_action_embed_func(fast_action_tokens): | ||
| fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens) | ||
| fast_emb_dim = fast_emb.shape[-1] | ||
| return fast_emb * math.sqrt(fast_emb_dim) | ||
| return fast_emb | ||
|
|
||
| fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) | ||
| embs.append(fast_action_emb) | ||
|
|
@@ -663,7 +661,6 @@ def sample_actions_fast( | |
| if t < max_decoding_steps - 1: | ||
| # embed the newly generated token | ||
| next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) | ||
| next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) | ||
| if prefix_embs.dtype == torch.bfloat16: | ||
| next_token_emb = next_token_emb.to(dtype=torch.bfloat16) | ||
|
|
||
|
|
@@ -768,7 +765,6 @@ def sample_actions_fast_kv_cache( | |
| # Embed the single previous token | ||
| # We use embed_language_tokens directly to avoid overhead of full prefix embedding | ||
| next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) | ||
| next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) | ||
| if prefix_embs.dtype == torch.bfloat16: | ||
| next_token_emb = next_token_emb.to(dtype=torch.bfloat16) | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.