diff --git a/.gitgnore b/.gitignore similarity index 96% rename from .gitgnore rename to .gitignore index 6160ebc..ec162be 100644 --- a/.gitgnore +++ b/.gitignore @@ -144,10 +144,14 @@ dmypy.json # Mac OS .DS_Store + # Caches and Datasets cache/ data/ - +pretrained_models/ # Rollout videos and wandb logs rollouts/ wandb/ +outputs/ +experiments/logs/ +evaluation_results/ \ No newline at end of file diff --git a/LIBERO b/LIBERO new file mode 160000 index 0000000..8f1084e --- /dev/null +++ b/LIBERO @@ -0,0 +1 @@ +Subproject commit 8f1084e3132a39270c3a13ebe37270a43ece2a01 diff --git a/calvin b/calvin new file mode 160000 index 0000000..fa03f01 --- /dev/null +++ b/calvin @@ -0,0 +1 @@ +Subproject commit fa03f01f19c65920e18cf37398a9ce859274af76 diff --git a/eval.sh b/eval.sh new file mode 100644 index 0000000..a15a096 --- /dev/null +++ b/eval.sh @@ -0,0 +1,14 @@ +export HF_HUB_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +export HF_DATASETS_OFFLINE=1 +CUDA_VISIBLE_DEVICES=1 python experiments/robot/libero/run_libero_eval.py \ + --use_proprio True \ + --num_images_in_input 2 \ + --use_film False \ + --pretrained_checkpoint outputs/configs+libero_10_no_noops+b16+lr-0.0002+lora-r64+dropout-0.0--image_aug--VLA-Adapter--libero_10_no_noops--2025_10_27_17_22_41--use_3d_True_dim_2048_inject_all--170000_chkpt \ + --task_suite_name libero_10 \ + --use_pro_version True \ + --use_3d True \ + --inject_layers all \ +# outputs/configs+libero_10_no_noops+b16+lr-0.0002+lora-r64+dropout-0.0--image_aug--VLA-Adapter--libero_10_no_noops--1759126170--160000_chkpt \ + # > eval_logs/Spatial--chkpt.log 2>&1 & \ No newline at end of file diff --git a/eval2.sh b/eval2.sh new file mode 100644 index 0000000..8408a59 --- /dev/null +++ b/eval2.sh @@ -0,0 +1,6 @@ +export HF_HUB_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +export HF_DATASETS_OFFLINE=1 +CUDA_VISIBLE_DEVICES=7 python vla-scripts/evaluate_calvin.py \ + --pretrained_checkpoint outputs/CALVIN-ABC-Pro + # --pretrained_checkpoint 
outputs/configs+calvin_abc_rlds+b16+lr-0.0002+lora-r64+dropout-0.0--image_aug--VLA-Adapter--calvin_abc_rlds----100000_chkpt \ \ No newline at end of file diff --git a/experiments/robot/libero/run_libero_eval.py b/experiments/robot/libero/run_libero_eval.py index bd82b14..88a76f2 100644 --- a/experiments/robot/libero/run_libero_eval.py +++ b/experiments/robot/libero/run_libero_eval.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, List import draccus import numpy as np @@ -48,7 +48,7 @@ set_seed_everywhere, ) from prismatic.vla.constants import NUM_ACTIONS_CHUNK - +from prismatic.models.pi3_loader import load_pc_model # Define task suite constants class TaskSuite(str, Enum): @@ -128,6 +128,11 @@ class GenerateConfig: use_pro_version: bool = True # encourage to use the pro models we released. phase: str = "Inference" + use_3d: bool = False + dim_3d: int = 2048 + pi3_path: Path = Path("/home/ruihengwang/vla/VLA-Adapter/pretrained_models/pi3_checkpoint") + inject_layers: Optional[int | List[int] | str] = None + def validate_config(cfg: GenerateConfig) -> None: @@ -292,6 +297,7 @@ def run_episode( noisy_action_projector=None, initial_state=None, log_file=None, + pi3_model=None ): """Run a single episode in the environment.""" # Reset environment @@ -342,7 +348,8 @@ def run_episode( proprio_projector=proprio_projector, noisy_action_projector=noisy_action_projector, use_film=cfg.use_film, - use_minivlm=cfg.use_minivlm + use_minivlm=cfg.use_minivlm, + pi3_model=pi3_model ) action_queue.extend(actions) @@ -383,7 +390,8 @@ def run_task( total_episodes=0, total_successes=0, log_file=None, - save_version=None + save_version=None, + pi3_model=None ): """Run evaluation for a single task.""" # Get task @@ -433,6 +441,7 @@ def run_task( noisy_action_projector, initial_state, log_file, + pi3_model ) # Update counters @@ -483,6 +492,10 @@ def eval_libero(cfg: 
GenerateConfig) -> float: # Initialize model and components model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg) + if cfg.use_3d: + pi3_model = load_pc_model(cfg.pi3_path) + else: + pi3_model = None # for name, param in model.named_parameters(): # if 'action_queries' in name: @@ -500,6 +513,7 @@ def eval_libero(cfg: GenerateConfig) -> float: num_tasks = task_suite.n_tasks log_message(f"Task suite: {cfg.task_suite_name}", log_file) + log_message(f"Using pretrained checkpoint: {cfg.pretrained_checkpoint}", log_file) # Start evaluation total_episodes, total_successes = 0, 0 @@ -517,7 +531,8 @@ def eval_libero(cfg: GenerateConfig) -> float: total_episodes, total_successes, log_file, - cfg.save_version + cfg.save_version, + pi3_model ) # Calculate final success rate diff --git a/experiments/robot/openvla_utils.py b/experiments/robot/openvla_utils.py index 03cb7c5..6278c2d 100644 --- a/experiments/robot/openvla_utils.py +++ b/experiments/robot/openvla_utils.py @@ -32,7 +32,7 @@ ACTION_PROPRIO_NORMALIZATION_TYPE, ) from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType - +from prismatic.models.pi3_loader import load_pc_model # Initialize important constants DATE = time.strftime("%Y_%m_%d") DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") @@ -507,6 +507,9 @@ def get_action_head(cfg: Any, llm_dim: int) -> Union[L1RegressionActionHead]: hidden_dim=llm_dim, action_dim=ACTION_DIM, use_pro_version=cfg.use_pro_version, + use_3d=cfg.use_3d, + dim_3d=cfg.dim_3d, + inject_layers=cfg.inject_layers, ) else: @@ -745,6 +748,8 @@ def get_vla_action( noisy_action_projector: Optional[torch.nn.Module] = None, use_film: bool = False, use_minivlm: bool = False, + use_3d_model: bool = False, + pi3_model: Optional[torch.nn.Module] = None ) -> List[np.ndarray]: """ Generate action predictions with the VLA policy. 
@@ -764,6 +769,11 @@ def get_vla_action( List[np.ndarray]: Predicted actions """ with torch.inference_mode(): + if use_3d_model: + assert pi3_model is not None + pi3_model = pi3_model.to(DEVICE).to(torch.bfloat16) + + # Collect all input images all_images = [obs["full_image"]] @@ -795,6 +805,24 @@ def get_vla_action( all_wrist_pixel_values = [wrist_inputs["pixel_values"] for wrist_inputs in all_wrist_inputs] inputs["pixel_values"] = torch.cat([primary_pixel_values] + all_wrist_pixel_values, dim=1) + if use_3d_model: + img_1, img_2 = inputs["pixel_values"][:, 0:3, :, :].to(DEVICE).to(torch.bfloat16), inputs["pixel_values"][:, 6:9, :, :].to(DEVICE).to(torch.bfloat16) + pi3_num_reg_token = 5 + + img_tensor = torch.stack([img_1, img_2], dim=1) # [B, 2, 3, H, W] where 2 indicates 2 views + B, N, _, H, W = img_tensor.shape + img_tensor = img_tensor.reshape((B*N, _, H, W)) + hidden = pi3_model.encoder(img_tensor, is_training=True) + if isinstance(hidden, dict): + hidden = hidden["x_norm_patchtokens"] + hidden, pos = pi3_model.decode(hidden, N, H, W) + hidden = hidden[:, pi3_num_reg_token:, :] + L_3d, dim_3d = hidden.shape[-2:] + hidden = hidden.reshape(B, -1, L_3d, dim_3d) + hidden = hidden.reshape(B, -1, dim_3d) + else: + hidden = None + # Process proprioception data if used proprio = None if cfg.use_proprio: @@ -819,6 +847,7 @@ def get_vla_action( noisy_action_projector=noisy_action_projector, action_head=action_head, use_film=use_film, + hidden_3d=hidden ) # Extract subset of actions for open loop steps diff --git a/experiments/robot/robot_utils.py b/experiments/robot/robot_utils.py index 61cedba..32c7806 100644 --- a/experiments/robot/robot_utils.py +++ b/experiments/robot/robot_utils.py @@ -107,6 +107,7 @@ def get_action( noisy_action_projector: Optional[torch.nn.Module] = None, use_film: bool = False, use_minivlm: bool = False, + pi3_model: Optional[torch.nn.Module] = None ) -> Union[List[np.ndarray], np.ndarray]: """ Query the model to get action predictions. 
@@ -140,7 +141,9 @@ def get_action( proprio_projector=proprio_projector, noisy_action_projector=noisy_action_projector, use_film=use_film, - use_minivlm=use_minivlm + use_minivlm=use_minivlm, + use_3d_model=cfg.use_3d, + pi3_model=pi3_model ) else: raise ValueError(f"Unsupported model family: {cfg.model_family}") diff --git a/pretrained_models/configs/modeling_prismatic.py b/pretrained_models/configs/modeling_prismatic.py index 945d03e..24bb0a4 100644 --- a/pretrained_models/configs/modeling_prismatic.py +++ b/pretrained_models/configs/modeling_prismatic.py @@ -428,6 +428,14 @@ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_ac Returns: Modified input_embeddings tensor """ + """ + * input_embeddings: [B, L_a + L_lang, Dim] + * all_actions_mask: [B, L_a + L_lang] + * noisy_action_features: [B, L_a, Dim] + * 此处其实是替换,我们 L_a + L_lang 这一串我们把 L_a 的部分,用 mask_indicies 索引从哪开始 L_a 这块 + * 我们 action_queries (论文核心设计)是 Embedding(num_tokens, dim) 的 weight + * 这一块是 [B, L_a + L_lang, Dim] 当中 L_a 替换成 action_queries 的 weight,L_lang 不动 + """ # Clone input to avoid modifying the original tensor new_input_embeddings = input_embeddings.clone() @@ -455,6 +463,15 @@ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_ac def _process_action_masks(self, labels): """Helper to get action masks from labels""" + """ + * IGNORE_INDEX = -100, labels 中从第一个 -100 开始, + * ACTION_TOKEN_BEGIN_IDX = 151386 + * NUM_TOKENS = 64, action 有 64 个 token ,从而 labels 一般是 64 个非 -100 。 + * ACTION_DIM = 7,current_action 是 labels 里 前 6 个,next_actions 是 后 58 个 + * 两个 mask 都是 Boolean。因此 1-48 是 -100, 49 - 54 是 curr_action, 55 - 110 是 next_actions, 后面都是 -100。 + * 因而 all_action_mask 其实就是 [B, L] 这里 每一个 sample 中 64 个是 True,表示第几个 token 是 action 的。 + * action 部分的 64 个就是 True。余下的是 False + """ current_action_mask = get_current_action_mask(labels) next_actions_mask = get_next_actions_mask(labels) all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len) @@ 
-462,6 +479,10 @@ def _process_action_masks(self, labels): def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False): """Process vision features with optional FiLM conditioning""" + """ + * 原设置没有 film condition,因此 language 的 feature embedding 不会传入给 vision transformer。 + * [B, 3 * num_images, H, W] --(vision)--> [B, 256 * num_images, D] --(projector)--> [B, 256 * num_images, llm_dim] + """ if use_film: # FiLM: Infuse language inputs into visual features patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D) @@ -473,6 +494,11 @@ def _process_vision_features(self, pixel_values, language_embeddings=None, use_f def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector): """Process proprioceptive features and append to vision features""" + """ + * 将 proprio 投影到 [B, D] 的 vector,然后 [B, 1, D] + * 然后 append 到尾部 + * 实际上没有使用。 + """ if proprio_projector is not None and proprio is not None: # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim) # proprio: (bsz, proprio_dim) or (propro_dim,) @@ -486,7 +512,13 @@ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask): """Build multimodal embeddings and attention mask""" # Update attention mask - + """ + * 这里 input_embedding 中 L_a 的部分已经被替换为 nn.Embedding 的 weight了。 + * 其实就是 input_embed 和 mask 在 length 上和 vision 的 embed 里 cat + * multimodal_embeddings: [B, 1 + L_v + (L_a + L_lang -1), Dim] 注意这个 1 是 token。L_v 被插在了这二者之间了。 + * multimodal_attention_mask: [B, 1 + L_v + (L_a + L_lang -1)]。 + * vision 部分的 mask [B, L_v] 是 全 True 的。 + """ projected_patch_attention_mask = None if attention_mask is not None: projected_patch_attention_mask = torch.full( @@ -511,6 +543,7 @@ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddin def _build_multimodal_labels(self, labels, 
projected_patch_embeddings): """Build multimodal labels with IGNORE_INDEX for patch embeddings""" + #* 所有 vision 部分的 index 都标为 -100(非 action 的 label),然后和原来 label [B, 1 + L_v + (L_a + L_lang -1)] 拼接 if labels is not None: projected_patch_labels = torch.full( (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), @@ -543,6 +576,22 @@ def forward( use_film: bool = False, ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]: """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.""" + """ + * Debug NOTE: + * input_ids has shape: [B, 120] with dtype: torch.int64 + ^ input_ids: + * attention_mask has shape: [B, 120] with dtype: torch.bool + ^ attention_mask [torch.where(~m)[0].tolist() for m in attention_mask] + ^ [[119], [109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119], [114, 115, 116, 117, 118, 119], ...] + + * pixel_values has shape: [B, 12, 224, 224] with dtype: torch.float32 + * labels has shape: [B, 120] with dtype: torch.int64 + ^ [(r[0].item(), r[-1].item()) if len(r:=torch.where(l!=-100)[0]) else (None,None) for l in labels] + ^ -100 一段 --> 非 -100 --> -100 一段 + ^ [(54, 118), (44, 108), (49, 113), (48, 112), (50, 114), (44, 108), (49, 113), (55, 119)] + + * proprio has shape: [B, 8] with dtype: torch.float32 + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -595,6 +644,10 @@ def forward( # === Handle Multimodal Forward === elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]): + + #! Entered here! + #* input_ids: [B, L_a+L_lang](int64) --(embedding)--> [B, L_a+L_lang, Dim](bfloat16) where 120 is the sequence len. + #* non -100 labels are acion tokens. assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!" 
# Get input embeddings (from language model embeddings) @@ -604,6 +657,11 @@ def forward( # Extract action masks all_actions_mask = self._process_action_masks(labels) + #* labels 有 64 个 非 -100 的 id,mask 也就是对应 64 个 位置是 True。这里也就是 labels 非 -100 的位置对应 True,说明是 action token + #* input_embeddings: [B, L_a + L_lang, Dim] + #* all_actions_mask 定位 L_a 起始终止 index。 + #* language_embeddings: [B, L_lang, Dim] + #* projected_patch_embeddings: [B, L_vis, Dim] # Extract the language portion of the input embeddings (i.e. remove the action tokens portion) # print(input_embeddings[~all_actions_mask].size()) @@ -639,6 +697,10 @@ def forward( # Build labels for multimodal sequence if needed multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings) + + #* multimodal_embeddings: [B, 1 + L_vis + (L_a + L_lang -1), Dim] + #* multimodal_attention_mask: [B, 1 + L_vis + (L_a + L_lang -1)] + #* mask 在 L_vis 和 L_a 为 True,余下为 False,这其实是说 Langugae 部分是 Causal 而 action,vis 是 bidirectional。 # Dispatch to language model language_model_output = self.language_model( @@ -817,6 +879,7 @@ def _regression_or_discrete_prediction( action_head=None, proprio=None, proprio_projector=None, + hidden_3d=None ): """Run L1 regression-based continuous action prediction or discrete action tokens prediction.""" @@ -867,7 +930,8 @@ def _regression_or_discrete_prediction( # L1 regression prediction normalized_actions = action_head.predict_action(multi_layer_hidden_states, proprio=proprio, - proprio_projector=proprio_projector) + proprio_projector=proprio_projector, + hidden_3d=hidden_3d) normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) normalized_actions = normalized_actions.float().cpu().detach().numpy() else: @@ -918,6 +982,7 @@ def predict_action( pixel_values = kwargs["pixel_values"] # [1, 12, 224, 224] attention_mask = kwargs["attention_mask"] # + hidden_3d = kwargs.get("hidden_3d", None) # Create fake labels tensor (needed for action mask) labels = 
input_ids.clone() @@ -964,6 +1029,7 @@ def predict_action( action_head=action_head, proprio=proprio, # [8] proprio_projector=proprio_projector, + hidden_3d=hidden_3d, ) # Unnormalize predicted actions diff --git a/prismatic/extern/hf/modeling_prismatic.py b/prismatic/extern/hf/modeling_prismatic.py index 945d03e..228c75f 100644 --- a/prismatic/extern/hf/modeling_prismatic.py +++ b/prismatic/extern/hf/modeling_prismatic.py @@ -428,6 +428,14 @@ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_ac Returns: Modified input_embeddings tensor """ + """ + * input_embeddings: [B, L_a + L_lang, Dim] + * all_actions_mask: [B, L_a + L_lang] + * noisy_action_features: [B, L_a, Dim] + * 此处其实是替换,我们 L_a + L_lang 这一串我们把 L_a 的部分,用 mask_indicies 索引从哪开始 L_a 这块 + * 我们 action_queries (论文核心设计)是 Embedding(num_tokens, dim) 的 weight + * 这一块是 [B, L_a + L_lang, Dim] 当中 L_a 替换成 action_queries 的 weight,L_lang 不动 + """ # Clone input to avoid modifying the original tensor new_input_embeddings = input_embeddings.clone() @@ -455,6 +463,15 @@ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_ac def _process_action_masks(self, labels): """Helper to get action masks from labels""" + """ + * IGNORE_INDEX = -100, labels 中从第一个 -100 开始, + * ACTION_TOKEN_BEGIN_IDX = 151386 + * NUM_TOKENS = 64, action 有 64 个 token ,从而 labels 一般是 64 个非 -100 。 + * ACTION_DIM = 7,current_action 是 labels 里 前 6 个,next_actions 是 后 58 个 + * 两个 mask 都是 Boolean。因此 1-48 是 -100, 49 - 54 是 curr_action, 55 - 110 是 next_actions, 后面都是 -100。 + * 因而 all_action_mask 其实就是 [B, L] 这里 每一个 sample 中 64 个是 True,表示第几个 token 是 action 的。 + * action 部分的 64 个就是 True。余下的是 False + """ current_action_mask = get_current_action_mask(labels) next_actions_mask = get_next_actions_mask(labels) all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len) @@ -462,6 +479,10 @@ def _process_action_masks(self, labels): def _process_vision_features(self, pixel_values, language_embeddings=None, 
use_film=False): """Process vision features with optional FiLM conditioning""" + """ + * 原设置没有 film condition,因此 language 的 feature embedding 不会传入给 vision transformer。 + * [B, 3 * num_images, H, W] --(vision)--> [B, 256 * num_images, D] --(projector)--> [B, 256 * num_images, llm_dim] + """ if use_film: # FiLM: Infuse language inputs into visual features patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D) @@ -473,6 +494,11 @@ def _process_vision_features(self, pixel_values, language_embeddings=None, use_f def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector): """Process proprioceptive features and append to vision features""" + """ + * 将 proprio 投影到 [B, D] 的 vector,然后 [B, 1, D] + * 然后 append 到尾部 + * 实际上没有使用。 + """ if proprio_projector is not None and proprio is not None: # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim) # proprio: (bsz, proprio_dim) or (propro_dim,) @@ -486,7 +512,13 @@ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask): """Build multimodal embeddings and attention mask""" # Update attention mask - + """ + * 这里 input_embedding 中 L_a 的部分已经被替换为 nn.Embedding 的 weight了。 + * 其实就是 input_embed 和 mask 在 length 上和 vision 的 embed 里 cat + * multimodal_embeddings: [B, 1 + L_v + (L_a + L_lang -1), Dim] 注意这个 1 是 token。L_v 被插在了这二者之间了。 + * multimodal_attention_mask: [B, 1 + L_v + (L_a + L_lang -1)]。 + * vision 部分的 mask [B, L_v] 是 全 True 的。 + """ projected_patch_attention_mask = None if attention_mask is not None: projected_patch_attention_mask = torch.full( @@ -511,6 +543,7 @@ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddin def _build_multimodal_labels(self, labels, projected_patch_embeddings): """Build multimodal labels with IGNORE_INDEX for patch embeddings""" + #* 所有 vision 部分的 index 都标为 -100(非 
action 的 label),然后和原来 label [B, 1 + L_v + (L_a + L_lang -1)] 拼接 if labels is not None: projected_patch_labels = torch.full( (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), @@ -543,6 +576,22 @@ def forward( use_film: bool = False, ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]: """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.""" + """ + * Debug NOTE: + * input_ids has shape: [B, 120] with dtype: torch.int64 + ^ input_ids: + * attention_mask has shape: [B, 120] with dtype: torch.bool + ^ attention_mask [torch.where(~m)[0].tolist() for m in attention_mask] + ^ [[119], [109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119], [114, 115, 116, 117, 118, 119], ...] + + * pixel_values has shape: [B, 12, 224, 224] with dtype: torch.float32 + * labels has shape: [B, 120] with dtype: torch.int64 + ^ [(r[0].item(), r[-1].item()) if len(r:=torch.where(l!=-100)[0]) else (None,None) for l in labels] + ^ -100 一段 --> 非 -100 --> -100 一段 + ^ [(54, 118), (44, 108), (49, 113), (48, 112), (50, 114), (44, 108), (49, 113), (55, 119)] + + * proprio has shape: [B, 8] with dtype: torch.float32 + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -595,6 +644,10 @@ def forward( # === Handle Multimodal Forward === elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]): + + #! Entered here! + #* input_ids: [B, L_a+L_lang](int64) --(embedding)--> [B, L_a+L_lang, Dim](bfloat16) where 120 is the sequence len. + #* non -100 labels are acion tokens. assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!" 
# Get input embeddings (from language model embeddings) @@ -604,13 +657,18 @@ def forward( # Extract action masks all_actions_mask = self._process_action_masks(labels) + #* labels 有 64 个 非 -100 的 id,mask 也就是对应 64 个 位置是 True。这里也就是 labels 非 -100 的位置对应 True,说明是 action token + #* input_embeddings: [B, L_a + L_lang, Dim] + #* all_actions_mask 定位 L_a 起始终止 index。 + #* language_embeddings: [B, L_lang, Dim] + #* projected_patch_embeddings: [B, L_vis, Dim] # Extract the language portion of the input embeddings (i.e. remove the action tokens portion) # print(input_embeddings[~all_actions_mask].size()) language_embeddings = input_embeddings[~all_actions_mask].reshape( input_embeddings.shape[0], -1, input_embeddings.shape[2] ) # (B, lang_seq_len, llm_dim) - + # Get visual features projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film) @@ -639,6 +697,10 @@ def forward( # Build labels for multimodal sequence if needed multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings) + + #* multimodal_embeddings: [B, 1 + L_vis + (L_a + L_lang -1), Dim] + #* multimodal_attention_mask: [B, 1 + L_vis + (L_a + L_lang -1)] + #* mask 在 L_vis 和 L_a 为 True,余下为 False,这其实是说 Langugae 部分是 Causal 而 action,vis 是 bidirectional。 # Dispatch to language model language_model_output = self.language_model( @@ -817,6 +879,7 @@ def _regression_or_discrete_prediction( action_head=None, proprio=None, proprio_projector=None, + hidden_3d=None ): """Run L1 regression-based continuous action prediction or discrete action tokens prediction.""" @@ -867,7 +930,8 @@ def _regression_or_discrete_prediction( # L1 regression prediction normalized_actions = action_head.predict_action(multi_layer_hidden_states, proprio=proprio, - proprio_projector=proprio_projector) + proprio_projector=proprio_projector, + hidden_3d=hidden_3d) normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) normalized_actions = 
normalized_actions.float().cpu().detach().numpy() else: @@ -918,6 +982,7 @@ def predict_action( pixel_values = kwargs["pixel_values"] # [1, 12, 224, 224] attention_mask = kwargs["attention_mask"] # + hidden_3d = kwargs.get("hidden_3d", None) # Create fake labels tensor (needed for action mask) labels = input_ids.clone() @@ -964,6 +1029,7 @@ def predict_action( action_head=action_head, proprio=proprio, # [8] proprio_projector=proprio_projector, + hidden_3d=hidden_3d, ) # Unnormalize predicted actions diff --git a/prismatic/models/action_heads.py b/prismatic/models/action_heads.py index 6719c96..9407bac 100644 --- a/prismatic/models/action_heads.py +++ b/prismatic/models/action_heads.py @@ -27,26 +27,51 @@ def __init__( action_dim=7, num_task_tokens=512, use_pro_version=False, + use_3d=False, + dim_3d=None, + inject_layers=None ): super().__init__() + self.use_3d = use_3d self.num_task_tokens = num_task_tokens self.action_dim = action_dim self.hidden_dim = hidden_dim - self.model = MLPResNet( - num_blocks=24, - input_dim=input_dim*ACTION_DIM, - hidden_dim=hidden_dim, - output_dim=action_dim, - use_pro_version=use_pro_version - ) + if not self.use_3d: + self.model = MLPResNet( + num_blocks=24, + input_dim=input_dim*ACTION_DIM, + hidden_dim=hidden_dim, + output_dim=action_dim, + use_pro_version=use_pro_version + ) + else: + assert dim_3d is not None, "dim_3d must be specified when use_3d is True!" 
+ self.model = MLPResNetw3d( + num_blocks=24, + input_dim=input_dim*ACTION_DIM, + hidden_dim=hidden_dim, + output_dim=action_dim, + use_pro_version=use_pro_version, + feat_3d_dim=dim_3d, + inject_layers=inject_layers + ) def predict_action( self, actions_hidden_states, proprio=None, proprio_projector=None, - phase="Inference" + phase="Inference", + **kwargs ): + """ + * action_hidden_states: [B, Hidden, L_v + L_a, Dim] + * proprio_hidden_states: + * proprio_projector: [B, P_dim] --> [B, 1, Dim] + * 输出时:task_hidden_states: [B, Hidden, L_v, Dim], action_hidden_states: [B, Hidden, L_a, Dim] + * cond_actions_hidden_states: [B, A_dim * A_chunk, Dim] --(reshape)-- [B, A_chunk, A_dim * Dim] + * 这 rearranged_actions_hidden_states 是 Learnable PE + """ batch_size = actions_hidden_states.shape[0] device = actions_hidden_states.device @@ -71,13 +96,23 @@ def predict_action( random_perturbations = learnable_random_perturbations(seq_len, dim, device=rearranged_actions_hidden_states.device, dtype=rearranged_actions_hidden_states.dtype) rearranged_actions_hidden_states = (rearranged_actions_hidden_states + random_perturbations) # (1, seq_len, dim) print("-----------------") - - action = self.model( - rearranged_actions_hidden_states, - h_a=actions_hidden_states, - p=proprio_features, - h_t=task_hidden_states - ) + if not self.use_3d: + action = self.model( + rearranged_actions_hidden_states, + h_a=actions_hidden_states, + p=proprio_features, + h_t=task_hidden_states + ) + else: + h_3d = kwargs.get("hidden_3d", None) + assert h_3d is not None, "h_3d must be passed when use_3d is True!" 
+ action = self.model( + rearranged_actions_hidden_states, + h_a=actions_hidden_states, + p=proprio_features, + h_t=task_hidden_states, + h_3d=h_3d + ) return action @@ -110,7 +145,8 @@ def __init__( def forward(self, x, h_a=None, h_t=None, p= None): - + #* [B, A_chunk, A_dim * Dim] -> [B, A_chunk, Dim] -> [B, A_chunk, A_dim] + #* 每一个 block 内部的过程是: # x: (batch_size, input_dim) x = self.layer_norm1(x) # shape: (batch_size, input_dim) x = self.fc1(x) # shape: (batch_size, hidden_dim) @@ -121,6 +157,62 @@ def forward(self, x, h_a=None, h_t=None, p= None): x = self.fc2(x) # shape: (batch_size, output_dim) return x +class MLPResNetw3d(nn.Module): + """MLP with residual connection blocks.""" + def __init__( + self, + num_blocks, + input_dim, + hidden_dim, + output_dim, + use_pro_version=True, + feat_3d_dim=2048, + inject_layers=0 + ): + + super().__init__() + self.layer_norm1 = nn.LayerNorm(input_dim) + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.relu = nn.ReLU() + self.mlp_resnet_blocks = nn.ModuleList() + # if use_3d_feat: + self.feat_3d_dim = feat_3d_dim + self.feat_3d_align = nn.Sequential( + nn.LayerNorm(self.feat_3d_dim), + nn.Linear(self.feat_3d_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim) + ) + self.inject_layers = inject_layers # TODO: inject 3D feat in only one layer + for i in range(num_blocks): + if self.inject_layers == "all": + self.mlp_resnet_blocks.append(MLPResNetBlock_Pro_w3d(dim=hidden_dim)) + elif isinstance(self.inject_layers, int) and i == self.inject_layers: + self.mlp_resnet_blocks.append(MLPResNetBlock_Pro_w3d(dim=hidden_dim)) + else: + self.mlp_resnet_blocks.append(MLPResNetBlock_Pro(dim=hidden_dim)) + + self.layer_norm2 = nn.LayerNorm(hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + + + def forward(self, x, h_a=None, h_t=None, p= None, h_3d=None): + #* [B, A_chunk, A_dim * Dim] -> [B, A_chunk, Dim] -> [B, A_chunk, A_dim] + #* 每一个 block 内部的过程是: + # x: (batch_size, input_dim) + h_3d = 
self.feat_3d_align(h_3d) + x = self.layer_norm1(x) # shape: (batch_size, input_dim) + x = self.fc1(x) # shape: (batch_size, hidden_dim) + x = self.relu(x) # shape: (batch_size, hidden_dim) + for i, block in enumerate(self.mlp_resnet_blocks): + if isinstance(block, MLPResNetBlock_Pro_w3d): + x = block(x, h_t = h_t[:,i+1,:], h_a = h_a[:,i+1,:], p=p, h_3d=h_3d) # shape: (batch_size, hidden_dim) + elif isinstance(block, MLPResNetBlock_Pro): + x = block(x, h_t = h_t[:,i+1,:], h_a = h_a[:,i+1,:], p=p) # shape: (batch_size, hidden_dim) + x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) + x = self.fc2(x) # shape: (batch_size, output_dim) + return x + def apply_rope(q, k, cos, sin): @@ -340,6 +432,14 @@ def forward(self, x, h_a=None, h_t=None, p=None): h_a: adapter tokens h_t: task tokens p: possible conditioning vector (for FiLM) + * x: [B, A_chunk, Dim] + * h_a: [B, L_a, Dim] + * h_t: [B, L_v, Dim] + * p: [B, 1, Dim] + * 三种:[B, n, A_chunk, dim], [B, n, L_a + p, dim], [B, n, L_v, dim] MHA 方式,加入 RoPE + * [B, n, A_chunk, dim] 的 q 和 自身的 k、h_t 的 k、h_a 的 k 分别做点积,得到三个 + * [B, n, A_chunk, A_chunk], [B, n, A_chunk, L_a + p], [B, n, A_chunk, L_v] , cat 就是 [B, n, A_chunk, A_chunk + (L_a + p) + L_v] + * 而 v 三者 cat 在一起就是 [B, n, A_chunk + (L_a + p) + L_v, dim] --> [B, n, A_chunk, dim] """ g = self.gating_factor ratio_g = torch.tanh(g) @@ -409,3 +509,149 @@ def reshape_heads(t, B, L): # residual + FFN x = self.ffn(output + x) return x + +class MLPResNetBlock_Pro_w3d(nn.Module): + """One MLP ResNet block with separate projections for self, adapter, task + RoPE, now with FiLM modulation.""" + + def __init__(self, dim: int, num_heads: int=8) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.ffn = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim), + nn.ReLU(), + ) + + # Q (from x only) + self.q_proj = nn.Linear(dim, dim) + + # Self-Attention: K, V + self.k_self = nn.Linear(dim, dim) + self.v_self = 
nn.Linear(dim, dim) + + # Adapter cross-attention: K, V + self.k_adapter = nn.Linear(dim, dim) + self.v_adapter = nn.Linear(dim, dim) + + # Task cross-attention: K, V + self.k_task = nn.Linear(dim, dim) + self.v_task = nn.Linear(dim, dim) + + self.k_3d = nn.Linear(dim, dim) + self.v_3d = nn.Linear(dim, dim) + + self.o_proj = nn.Linear(dim, dim) + + # gating + self.gating_factor = nn.Parameter(torch.zeros(1)) + + # RoPE + self.rope = RotaryPositionEmbedding(self.head_dim) + + # ---- FiLM ---- + # FiLM is useless; to avoid conflict with chkpt, it can be kept as is for now. + self.film_gen = nn.Sequential( + nn.Linear(dim, dim * 2), # output γ and β + ) + + + def apply_film(self, x, gamma, beta): + """FiLM: per-channel modulation""" + return gamma.unsqueeze(1) * x + beta.unsqueeze(1) + + + def forward(self, x, h_a=None, h_t=None, p=None, h_3d=None): + """ + h_a: adapter tokens + h_t: task tokens + p: possible conditioning vector (for FiLM) + * x: [B, A_chunk, Dim] + * h_a: [B, L_a, Dim] + * h_t: [B, L_v, Dim] + * p: [B, 1, Dim] + * 三种:[B, n, A_chunk, dim], [B, n, L_a + p, dim], [B, n, L_v, dim] MHA 方式,加入 RoPE + * [B, n, A_chunk, dim] 的 q 和 自身的 k、h_t 的 k、h_a 的 k 分别做点积,得到三个 + * [B, n, A_chunk, A_chunk], [B, n, A_chunk, L_a + p], [B, n, A_chunk, L_v] , cat 就是 [B, n, A_chunk, A_chunk + (L_a + p) + L_v] + * 而 v 三者 cat 在一起就是 [B, n, A_chunk + (L_a + p) + L_v, dim] --> [B, n, A_chunk, dim] + """ + g = self.gating_factor + ratio_g = torch.tanh(g) + + # concat h_a and p + h_adapter = torch.cat((h_a, p),dim=1) + + + h_task = h_t + B, T, C = x.shape + K_a = h_adapter.size(1) if h_a is not None else 0 + K_t = h_task.size(1) if h_task is not None else 0 + K_3d = h_3d.size(1) if h_3d is not None else 0 + + # Q + q_1 = self.q_proj(x) + + # self tokens + k_tokens = self.k_self(x) + v_tokens = self.v_self(x) + + # adapter tokens + k_adapter = self.k_adapter(h_adapter) + v_adapter = self.v_adapter(h_adapter) + + # task tokens + k_task = self.k_task(h_task) + v_task = self.v_task(h_task) 
+ + # 3D tokens + k_3d = self.k_3d(h_3d) + v_3d = self.v_3d(h_3d) + + + # reshape -> multi-head + def reshape_heads(t: torch.Tensor, B: int, L: int) -> torch.Tensor: + return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2) + + + q_1 = reshape_heads(q_1, B, T) + k_tokens, v_tokens = reshape_heads(k_tokens, B, T), reshape_heads(v_tokens, B, T) + k_adapter, v_adapter = reshape_heads(k_adapter, B, K_a), reshape_heads(v_adapter, B, K_a) + k_task, v_task = reshape_heads(k_task, B, K_t), reshape_heads(v_task, B, K_t) + k_3d, v_3d = reshape_heads(k_3d, B, K_3d), reshape_heads(v_3d, B, K_3d) + + # RoPE + cos_main, sin_main = self.rope(seq_len=T, device=x.device, dtype=x.dtype) + q_1, k_tokens = apply_rope(q_1, k_tokens, cos_main, sin_main) + cos_a, sin_a = self.rope(seq_len=K_a, device=x.device, dtype=x.dtype) + _, k_adapter = apply_rope(k_adapter, k_adapter, cos_a, sin_a) + cos_t, sin_t = self.rope(seq_len=K_t, device=x.device, dtype=x.dtype) + _, k_task = apply_rope(k_task, k_task, cos_t, sin_t) + cos3d, sin3d = self.rope(seq_len=K_3d, device=x.device, dtype=x.dtype) + _, k_3d = apply_rope(k_3d, k_3d, cos3d, sin3d) + + # attention scores + attn_scores = [torch.matmul(q_1, k_tokens.transpose(-2, -1))] + attn_scores.append(torch.matmul(q_1, k_adapter.transpose(-2, -1))) + attn_scores.append(torch.matmul(q_1, k_task.transpose(-2, -1)) * ratio_g) + attn_scores.append(torch.matmul(q_1, k_3d.transpose(-2, -1))) + attn_scores = torch.cat(attn_scores, dim=-1) / math.sqrt(self.head_dim) + attn_weights = torch.softmax(attn_scores, dim=-1) + + # combine V + v_list = [v_tokens, v_adapter, v_task, v_3d] + v_combined = torch.cat(v_list, dim=2) + + output = torch.matmul(attn_weights, v_combined) + output = output.transpose(1, 2).contiguous().view(B, T, C) + output = self.o_proj(output) + + # # ---- FiLM ---- + # gamma_beta = self.film_gen(p) # [B, 2C] + # gamma, beta = gamma_beta.chunk(2, dim=-1) # [B, C], [B, C] + # output = self.apply_film(output, gamma, beta) + + # 
residual + FFN + x = self.ffn(output + x) + return x \ No newline at end of file diff --git a/prismatic/models/backbones/llm/__pycache__/__init__.cpython-310.pyc b/prismatic/models/backbones/llm/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 1ee1bfc..0000000 Binary files a/prismatic/models/backbones/llm/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/__pycache__/base_llm.cpython-310.pyc b/prismatic/models/backbones/llm/__pycache__/base_llm.cpython-310.pyc deleted file mode 100644 index 704d339..0000000 Binary files a/prismatic/models/backbones/llm/__pycache__/base_llm.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/__pycache__/llama2.cpython-310.pyc b/prismatic/models/backbones/llm/__pycache__/llama2.cpython-310.pyc deleted file mode 100644 index b462ca8..0000000 Binary files a/prismatic/models/backbones/llm/__pycache__/llama2.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/__pycache__/mistral.cpython-310.pyc b/prismatic/models/backbones/llm/__pycache__/mistral.cpython-310.pyc deleted file mode 100644 index d7ad9c7..0000000 Binary files a/prismatic/models/backbones/llm/__pycache__/mistral.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/__pycache__/phi.cpython-310.pyc b/prismatic/models/backbones/llm/__pycache__/phi.cpython-310.pyc deleted file mode 100644 index 6c50932..0000000 Binary files a/prismatic/models/backbones/llm/__pycache__/phi.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/__pycache__/qwen25.cpython-310.pyc b/prismatic/models/backbones/llm/__pycache__/qwen25.cpython-310.pyc deleted file mode 100644 index e60b698..0000000 Binary files a/prismatic/models/backbones/llm/__pycache__/qwen25.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/prompting/__pycache__/__init__.cpython-310.pyc 
b/prismatic/models/backbones/llm/prompting/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 9db5d29..0000000 Binary files a/prismatic/models/backbones/llm/prompting/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/prompting/__pycache__/base_prompter.cpython-310.pyc b/prismatic/models/backbones/llm/prompting/__pycache__/base_prompter.cpython-310.pyc deleted file mode 100644 index 76328d9..0000000 Binary files a/prismatic/models/backbones/llm/prompting/__pycache__/base_prompter.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/prompting/__pycache__/llama2_chat_prompter.cpython-310.pyc b/prismatic/models/backbones/llm/prompting/__pycache__/llama2_chat_prompter.cpython-310.pyc deleted file mode 100644 index c928e24..0000000 Binary files a/prismatic/models/backbones/llm/prompting/__pycache__/llama2_chat_prompter.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/prompting/__pycache__/mistral_instruct_prompter.cpython-310.pyc b/prismatic/models/backbones/llm/prompting/__pycache__/mistral_instruct_prompter.cpython-310.pyc deleted file mode 100644 index 53bfe1c..0000000 Binary files a/prismatic/models/backbones/llm/prompting/__pycache__/mistral_instruct_prompter.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/prompting/__pycache__/phi_prompter.cpython-310.pyc b/prismatic/models/backbones/llm/prompting/__pycache__/phi_prompter.cpython-310.pyc deleted file mode 100644 index 180d546..0000000 Binary files a/prismatic/models/backbones/llm/prompting/__pycache__/phi_prompter.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/prompting/__pycache__/qwen_prompter.cpython-310.pyc b/prismatic/models/backbones/llm/prompting/__pycache__/qwen_prompter.cpython-310.pyc deleted file mode 100644 index b5ad167..0000000 Binary files 
a/prismatic/models/backbones/llm/prompting/__pycache__/qwen_prompter.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/llm/prompting/__pycache__/vicuna_v15_prompter.cpython-310.pyc b/prismatic/models/backbones/llm/prompting/__pycache__/vicuna_v15_prompter.cpython-310.pyc deleted file mode 100644 index 52a3192..0000000 Binary files a/prismatic/models/backbones/llm/prompting/__pycache__/vicuna_v15_prompter.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/__init__.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 28c8c56..0000000 Binary files a/prismatic/models/backbones/vision/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/base_vision.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/base_vision.cpython-310.pyc deleted file mode 100644 index c036729..0000000 Binary files a/prismatic/models/backbones/vision/__pycache__/base_vision.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/clip_vit.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/clip_vit.cpython-310.pyc deleted file mode 100644 index 62ac130..0000000 Binary files a/prismatic/models/backbones/vision/__pycache__/clip_vit.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/dinoclip_vit.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/dinoclip_vit.cpython-310.pyc deleted file mode 100644 index 1fa451c..0000000 Binary files a/prismatic/models/backbones/vision/__pycache__/dinoclip_vit.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/dinosiglip_vit.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/dinosiglip_vit.cpython-310.pyc deleted file mode 100644 index ac394c9..0000000 Binary files 
a/prismatic/models/backbones/vision/__pycache__/dinosiglip_vit.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/dinov2_vit.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/dinov2_vit.cpython-310.pyc deleted file mode 100644 index fbec061..0000000 Binary files a/prismatic/models/backbones/vision/__pycache__/dinov2_vit.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/in1k_vit.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/in1k_vit.cpython-310.pyc deleted file mode 100644 index 392ecaa..0000000 Binary files a/prismatic/models/backbones/vision/__pycache__/in1k_vit.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/backbones/vision/__pycache__/siglip_vit.cpython-310.pyc b/prismatic/models/backbones/vision/__pycache__/siglip_vit.cpython-310.pyc deleted file mode 100644 index a9cae53..0000000 Binary files a/prismatic/models/backbones/vision/__pycache__/siglip_vit.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/pc_encoder.py b/prismatic/models/pc_encoder.py new file mode 100644 index 0000000..9b1040c --- /dev/null +++ b/prismatic/models/pc_encoder.py @@ -0,0 +1,450 @@ + +""" +pc_encoder.py + +Implementations of pointcloud encoder in iDP3, which also supports 2D point cloud maps. 
# reference: https://github.com/YanjieZe/Improved-3D-Diffusion-Policy/blob/main/Improved-3D-Diffusion-Policy/diffusion_policy_3d/model/vision_3d
from typing import Tuple, List, Optional, Dict, Union, Type
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fall back to stdlib logging when the project logger is unavailable so this
# module stays importable in isolation (tests, tooling).
try:
    from prismatic.overwatch import initialize_overwatch
    overwatch = initialize_overwatch(__name__)
except ImportError:
    import logging
    overwatch = logging.getLogger(__name__)

# ==================== Utility Functions ====================

def meanpool(x, dim=-1, keepdim=False):
    """Mean-pool `x` along `dim`."""
    return x.mean(dim=dim, keepdim=keepdim)


def maxpool(x, dim=-1, keepdim=False):
    """Max-pool `x` along `dim` (values only, indices dropped)."""
    return x.max(dim=dim, keepdim=keepdim).values


def shuffle_point_torch(point_cloud):
    """Randomly permute the points of a [B, N, C] cloud (same permutation for every batch item)."""
    B, N, C = point_cloud.shape
    # randperm placed on the input's device so CUDA inputs never index with CPU indices
    indices = torch.randperm(N, device=point_cloud.device)
    return point_cloud[:, indices]


def pad_point_torch(point_cloud, num_points):
    """Zero-pad a [B, N, C] cloud up to `num_points` points, then shuffle so
    the padding points are spread randomly through the cloud."""
    B, N, C = point_cloud.shape
    if num_points > N:
        pad_points = torch.zeros(B, num_points - N, C, device=point_cloud.device)
        point_cloud = torch.cat([point_cloud, pad_points], dim=1)
        point_cloud = shuffle_point_torch(point_cloud)
    return point_cloud


def uniform_sampling_torch(point_cloud, num_points):
    """Sample (or pad) a [B, N, C] cloud to exactly `num_points` points."""
    B, N, C = point_cloud.shape
    if num_points == N:
        return point_cloud
    if num_points > N:
        return pad_point_torch(point_cloud, num_points)
    # random down-sampling without replacement
    indices = torch.randperm(N, device=point_cloud.device)[:num_points]
    return point_cloud[:, indices]


# ==================== 2D Point Cloud utility ====================

def shuffle_map_torch(map_data):
    """Randomly permute the spatial cells of a [B, H, W, C] map."""
    B, H, W, C = map_data.shape
    map_flat = map_data.view(B, H * W, C)
    indices = torch.randperm(H * W, device=map_data.device)
    map_shuffled = map_flat[:, indices, :]
    return map_shuffled.view(B, H, W, C)


def pad_map_torch(map_data, target_size):
    """Zero-pad a [B, H, W, C] map up to a square `target_size` map.

    NOTE(review): assumes H <= target_size AND W <= target_size when padding is
    needed; a map that is larger along one axis only would fail the copy below.
    """
    B, H, W, C = map_data.shape
    if target_size > H or target_size > W:
        padded_map = torch.zeros(B, target_size, target_size, C, device=map_data.device)
        padded_map[:, :H, :W, :] = map_data  # original data in the top-left corner
        return shuffle_map_torch(padded_map)  # distribute zeros randomly
    return map_data


def resize_map_torch(map_data, target_size):
    """Bilinearly resize a [B, H, W, C] map to `target_size` (int or (H, W))."""
    map_permuted = map_data.permute(0, 3, 1, 2)  # F.interpolate wants [B, C, H, W]
    if isinstance(target_size, int):
        target_size = (target_size, target_size)
    map_resized = F.interpolate(
        map_permuted,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    return map_resized.permute(0, 2, 3, 1)  # back to [B, H, W, C]


def crop_map_torch(map_data, target_size):
    """Random-crop a [B, H, W, C] map to `target_size` (pads first if too small)."""
    B, H, W, C = map_data.shape
    if isinstance(target_size, int):
        target_h = target_w = target_size
    else:
        target_h, target_w = target_size
    if H < target_h or W < target_w:
        return pad_map_torch(map_data, max(target_h, target_w))
    top = torch.randint(0, H - target_h + 1, (1,)).item()
    left = torch.randint(0, W - target_w + 1, (1,)).item()
    return map_data[:, top:top + target_h, left:left + target_w, :]


def uniform_sampling_map_torch(map_data, target_size, method='resize'):
    """
    Unified sampling function for 2D maps
    Args:
        map_data: [B, H, W, 3]
        target_size: int or tuple (target_H, target_W)
        method: 'resize', 'crop', or 'pad'
    Returns:
        sampled map: [B, target_H, target_W, 3]
    """
    B, H, W, C = map_data.shape
    if isinstance(target_size, int):
        target_h = target_w = target_size
    else:
        target_h, target_w = target_size
    if H == target_h and W == target_w:
        return map_data
    if method == 'resize':
        return resize_map_torch(map_data, target_size)
    elif method == 'crop':
        return crop_map_torch(map_data, target_size)
    elif method == 'pad':
        return pad_map_torch(map_data, target_size)
    else:
        raise ValueError(f"Unknown method: {method}. Use 'resize', 'crop', or 'pad'.")


# ==================== 1D Point Cloud Encoder ====================

class MultiStagePointNetEncoder(nn.Module):
    """1D Point Cloud Encoder using 1D convolutions.

    Each stage mixes per-point features with a broadcast global max-pool
    feature; stage outputs are concatenated before the output projection.
    """
    def __init__(self, h_dim=128, out_channels=128, num_layers=4, **kwargs):
        super().__init__()
        self.h_dim = h_dim
        self.out_channels = out_channels
        self.num_layers = num_layers

        self.act = nn.LeakyReLU()
        self.conv_in = nn.Conv1d(3, h_dim, kernel_size=1)
        self.layers, self.global_layers = nn.ModuleList(), nn.ModuleList()
        for _ in range(self.num_layers):
            self.layers.append(nn.Conv1d(h_dim, h_dim, kernel_size=1))
            self.global_layers.append(nn.Conv1d(h_dim * 2, h_dim, kernel_size=1))
        self.conv_out = nn.Conv1d(h_dim * self.num_layers, out_channels, kernel_size=1)

    def forward(self, x):
        """[B, L, 3] -> [B, out_channels] global feature."""
        assert x.shape[-1] == 3, f"Input shape must have 3 channels at the last dim, got {x.shape}"
        x = x.transpose(1, 2)  # [B, 3, L]
        y = self.act(self.conv_in(x))
        feat_list = []
        for i in range(self.num_layers):
            y = self.act(self.layers[i](y))
            y_global = y.max(-1, keepdim=True).values  # stage-global feature
            y = torch.cat([y, y_global.expand_as(y)], dim=1)
            y = self.act(self.global_layers[i](y))
            feat_list.append(y)
        x = torch.cat(feat_list, dim=1)
        x = self.conv_out(x)
        return x.max(-1).values  # [B, out_channels]


# ==================== 2D Point Cloud Map Encoder ====================

class MultiStageMapNetEncoder(nn.Module):
    """2D Point Cloud Map Encoder using 2D convolutions (same multi-stage
    local/global mixing scheme as the 1D encoder)."""
    def __init__(self, h_dim=128, out_channels=128, num_layers=4, **kwargs):
        super().__init__()
        self.h_dim = h_dim
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.act = nn.LeakyReLU()
        self.conv_in = nn.Conv2d(3, h_dim, kernel_size=3, padding=1)

        self.layers = nn.ModuleList()
        self.global_layers = nn.ModuleList()
        for _ in range(self.num_layers):
            self.layers.append(nn.Conv2d(h_dim, h_dim, kernel_size=3, stride=1, padding=1))
            self.global_layers.append(nn.Conv2d(h_dim * 2, h_dim, kernel_size=1, stride=1, padding=0))
        self.conv_out = nn.Conv2d(h_dim * self.num_layers, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """[B, 3, H, W] -> [B, out_channels] global feature."""
        y = self.act(self.conv_in(x))
        feat_list = []
        for i in range(self.num_layers):
            y = self.act(self.layers[i](y))
            y_global = F.adaptive_max_pool2d(y, 1)  # [B, h_dim, 1, 1]
            y = torch.cat([y, y_global.expand_as(y)], dim=1)
            y = self.act(self.global_layers[i](y))
            feat_list.append(y)
        x = torch.cat(feat_list, dim=1)
        x = self.conv_out(x)
        x_global = F.adaptive_max_pool2d(x, 1)  # [B, out_channels, 1, 1]
        return x_global.squeeze(-1).squeeze(-1)  # [B, out_channels]


# ==================== Unified iDP3 Encoder ====================

class iDP3Encoder(nn.Module):
    """
    Unified Point Cloud Encoder

    Supports 4 input formats:
        1. [B, L, 3]       - Single 1D point cloud per batch
        2. [B, N, L, 3]    - Multiple 1D point clouds (N views; pass multi_scene=True)
        3. [B, H, W, 3]    - Single 2D point cloud map per batch
        4. [B, N, H, W, 3] - Multiple 2D point cloud maps (N views)

    Output:
        - single input:     [B, out_channels]
        - multi-view input: [B, N, out_channels] (each view encoded separately)
    """
    def __init__(self,
                 out_channels=128,
                 num_points=4096,        # only used for 1D point cloud downsampling
                 target_map_size=224,    # target size for 2D map processing
                 h_dim=128,
                 num_layers=4,
                 point_downsample=True,  # only for 1D point clouds
                 map_sampling_method='resize',  # 'resize', 'crop', or 'pad' for 2D maps
                 ):
        super().__init__()
        self.n_output_channels = out_channels
        self.num_points = num_points
        self.downsample = point_downsample
        self.target_map_size = target_map_size
        self.map_sampling_method = map_sampling_method

        self.pointnet_encoder = MultiStagePointNetEncoder(
            h_dim=h_dim, out_channels=out_channels, num_layers=num_layers)
        self.mapnet_encoder = MultiStageMapNetEncoder(
            h_dim=h_dim, out_channels=out_channels, num_layers=num_layers)

        overwatch.info(f"iDP3 Encoder has num layers: {num_layers}, h_dim: {h_dim}, output dim: {self.n_output_channels}")

    def _encode_1d_pointcloud(self, pc):
        """[B, L, 3] -> [B, out_channels], optionally downsampled to num_points."""
        if self.downsample:
            pc = uniform_sampling_torch(pc, self.num_points)
        return self.pointnet_encoder(pc)

    def _encode_2d_map(self, map_data):
        """[B, H, W, 3] -> [B, out_channels], resampled to target_map_size first."""
        # BUG FIX: H and W were previously referenced without being unpacked,
        # so every 2D-map input raised NameError.
        B, H, W, C = map_data.shape
        if H != self.target_map_size or W != self.target_map_size:
            map_data = uniform_sampling_map_torch(
                map_data,
                self.target_map_size,
                method=self.map_sampling_method,
            )
        map_data = map_data.permute(0, 3, 1, 2)  # [B, 3, H, W]
        return self.mapnet_encoder(map_data)

    def forward(self, x: torch.Tensor, multi_scene: bool = False) -> torch.Tensor:
        """
        Args:
            x: Point cloud tensor in one of these formats:
                - [B, L, 3]:       single 1D point cloud
                - [B, N, L, 3]:    multiple 1D point clouds (requires multi_scene=True,
                                   since the shape is ambiguous with a 2D map)
                - [B, H, W, 3]:    single 2D point cloud map
                - [B, N, H, W, 3]: multiple 2D point cloud maps

        Returns:
            [B, out_channels] for single input, [B, N, out_channels] for multi-view.
        """
        assert x.shape[-1] == 3, f"Last dimension must be 3 (XYZ), got {x.shape[-1]}"

        ndim = len(x.shape)
        if ndim == 3:
            # Case 1: [B, L, 3]
            return self._encode_1d_pointcloud(x)
        elif ndim == 4:
            if multi_scene:
                # Case 2: [B, N, L, 3] — fold views into the batch dim
                B, N, L, C = x.shape
                features = self._encode_1d_pointcloud(x.view(B * N, L, C))
                return features.view(B, N, -1)
            else:
                # Case 3: [B, H, W, 3]
                return self._encode_2d_map(x)
        elif ndim == 5:
            # Case 4: [B, N, H, W, 3] — fold views into the batch dim
            B, N, H, W, C = x.shape
            features = self._encode_2d_map(x.view(B * N, H, W, C))
            return features.view(B, N, -1)
        else:
            raise ValueError(f"Unsupported input shape: {x.shape}. Expected 3, 4, or 5 dimensions.")

    @property
    def output_shape(self) -> int:
        return self.n_output_channels


# ==================== Main Test ====================

if __name__ == "__main__":
    print("Testing Unified iDP3Encoder")
    encoder = iDP3Encoder(out_channels=256, num_points=4096, h_dim=128,
                          num_layers=4, point_downsample=True)
    print(f"Total parameters: {sum(p.numel() for p in encoder.parameters()):,}")
    encoder.eval()

    # (input, forward kwargs, expected output shape) for all four formats
    cases = [
        (torch.randn(4, 8192, 3), {}, (4, 256)),
        (torch.randn(4, 3, 8192, 3), {"multi_scene": True}, (4, 3, 256)),
        (torch.randn(4, 224, 224, 3), {}, (4, 256)),
        (torch.randn(4, 5, 224, 224, 3), {}, (4, 5, 256)),
    ]
    for i, (inp, kwargs, expected) in enumerate(cases, 1):
        with torch.no_grad():
            out = encoder(inp, **kwargs)
        assert out.shape == expected, f"Case {i}: got {out.shape}, expected {expected}"
        print(f"Case {i}: {tuple(inp.shape)} -> {tuple(out.shape)} ✓")
    print("All tests passed!")
Unified encoder works correctly for all cases.") + print("="*70) \ No newline at end of file diff --git a/prismatic/models/pi3/models/dinov2/__init__.py b/prismatic/models/pi3/models/dinov2/__init__.py new file mode 100644 index 0000000..ae847e4 --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/prismatic/models/pi3/models/dinov2/hub/__init__.py b/prismatic/models/pi3/models/dinov2/hub/__init__.py new file mode 100644 index 0000000..b88da6b --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/prismatic/models/pi3/models/dinov2/hub/backbones.py b/prismatic/models/pi3/models/dinov2/hub/backbones.py new file mode 100644 index 0000000..53fe837 --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = "LVD142M"


# Extra kwargs shared by every "*_reg" (register-token) variant.
_REGISTER_KWARGS = dict(
    num_register_tokens=4,
    interpolate_antialias=True,
    interpolate_offset=0.0,
)


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Build a DINOv2 ViT and optionally load its pretrained LVD-142M weights."""
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_config = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_config.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_config)

    if pretrained:
        full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{base_name}/{full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_small", pretrained=pretrained, weights=weights,
        **_REGISTER_KWARGS, **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_base", pretrained=pretrained, weights=weights,
        **_REGISTER_KWARGS, **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_large", pretrained=pretrained, weights=weights,
        **_REGISTER_KWARGS, **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_giant2", ffn_layer="swiglufused",
        weights=weights, pretrained=pretrained,
        **_REGISTER_KWARGS, **kwargs,
    )
+ +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/prismatic/models/pi3/models/dinov2/layers/__init__.py b/prismatic/models/pi3/models/dinov2/layers/__init__.py new file mode 100644 index 0000000..05a0b61 --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/prismatic/models/pi3/models/dinov2/layers/attention.py b/prismatic/models/pi3/models/dinov2/layers/attention.py new file mode 100644 index 0000000..3fed573 --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if 
class Block(nn.Module):
    """Pre-norm transformer block: attention + MLP residual branches, with
    optional LayerScale and per-sample stochastic depth.

    Args:
        dim: embedding dimension.
        num_heads: number of attention heads.
        mlp_ratio: FFN hidden size = int(dim * mlp_ratio).
        qkv_bias / proj_bias / ffn_bias: enable biases in the respective projections.
        drop: dropout on the attention output projection and inside the FFN.
        attn_drop: dropout on the attention weights.
        init_values: LayerScale init value; None or 0 disables LayerScale.
        drop_path: stochastic-depth rate (also used as the batch-subset drop ratio).
        attn_class: attention implementation (Attention / MemEffAttention / ...).
        ffn_layer: feed-forward implementation (Mlp / SwiGLU variants).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # batch-subset stochastic depth: the overhead is compensated only
            # for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fixed: the FFN branch now uses its own drop_path2 module (was
            # drop_path1, flagged by the upstream "FIXME: drop_path2"). Both
            # modules are DropPath(drop_path), so the drop distribution is
            # unchanged; this only routes through the intended module.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Stochastic depth over the batch dimension.

    Computes the residual branch only for a random subset of the batch, then
    scatter-adds the result back, scaled by batch/subset so the expected
    contribution matches running the branch on every sample.
    """
    # 1) pick a random subset of batch rows (at least one)
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]

    # 2) run the branch on the kept rows only
    residual = residual_func(x[kept_idx])

    flat_x = x.flatten(1)
    flat_res = residual.flatten(1)
    scale = batch / keep

    # 3) scatter-add the scaled residual back into the full batch
    result = torch.index_add(flat_x, 0, kept_idx, flat_res.to(dtype=x.dtype), alpha=scale)
    return result.view_as(x)
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """List variant of batch-subset stochastic depth.

    Subsamples each tensor's batch, runs the residual branch once over the
    nested (block-diagonal) concatenation, then scatters the scaled residuals
    back into each input tensor.
    """
    # 1) per-tensor random subset of batch rows, plus its compensation scale
    subset_info = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [info[0] for info in subset_info]
    scales = [info[1] for info in subset_info]

    # 2) block-diagonal attention bias and the single concatenated tensor
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) run the branch once over the concatenation, split back per input
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, scales)
    ]
class NestedTensorBlock(Block):
    """Block variant that can process a list of tensors as one nested batch
    (requires xFormers' memory-efficient attention)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fixed: guard on ls2 (the FFN-branch LayerScale), not ls1.
                # Both are constructed together so the types match in practice,
                # but the isinstance check now matches the gamma it gates.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # Single tensor -> plain Block path; list -> nested path (xFormers only).
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
class DINOHead(nn.Module):
    """DINO projection head: MLP -> L2-normalize -> weight-normalized linear."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        # Weight-normalized last layer with its magnitude pinned to 1.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal init for Linear weights, zero biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        projected = self.mlp(x)
        # Larger eps under fp16 to avoid blow-ups in the division.
        eps = 1e-6 if projected.dtype == torch.float16 else 1e-12
        normalized = nn.functional.normalize(projected, dim=-1, p=2, eps=eps)
        return self.last_layer(normalized)


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """Stack of Linear(+BatchNorm)+GELU layers ending in a bottleneck projection."""
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    layers = []
    dims = [in_dim] + [hidden_dim] * (nlayers - 1)
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(d_in, d_out, bias=bias))
        if use_bn:
            layers.append(nn.BatchNorm1d(d_out))
        layers.append(nn.GELU())
    layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
    return nn.Sequential(*layers)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors by 1/keep_prob."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
class LayerScale(nn.Module):
    """Per-channel learnable scaling of the residual branch (LayerScale, CaiT)."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # In-place multiply saves a buffer when the caller allows it.
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out, bias=bias)
        # A single Dropout module is shared by both positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        def _as_2tuple(v):
            # Same contract as the module-level make_2tuple helper.
            if isinstance(v, tuple):
                assert len(v) == 2
                return v
            assert isinstance(v, int)
            return (v, v)

        image_HW = _as_2tuple(img_size)
        patch_HW = _as_2tuple(patch_size)

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping conv = linear projection of each patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        tokens = self.proj(x)  # B C H' W'
        grid_h, grid_w = tokens.size(2), tokens.size(3)
        tokens = self.norm(tokens.flatten(2).transpose(1, 2))  # B HW C
        if not self.flatten_embedding:
            tokens = tokens.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C
        return tokens

    def flops(self) -> float:
        grid_h, grid_w = self.patches_resolution
        total = grid_h * grid_w * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += grid_h * grid_w * self.embed_dim
        return total
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: a fused projection produces gate and value halves,
    combined as silu(gate) * value, then projected back out.

    `act_layer` and `drop` are accepted for interface parity with Mlp but
    are unused by this implementation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden, bias=bias)
        self.w3 = nn.Linear(hidden, out, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
def build_model(args, only_teacher=False, img_size=224):
    """Construct DINOv2 student/teacher ViTs from a config namespace.

    Args:
        args: config object carrying the arch name and ViT hyperparameters
            (patch_size, layerscale, ffn_layer, block_chunks, biases,
            register tokens, interpolation flags, drop-path rates).
            NOTE: mutated in place — any "_memeff" suffix is stripped from
            ``args.arch`` before the lookup.
        only_teacher: if True, build and return only the teacher network.
        img_size: global crop size fed to the patch embedding.

    Returns:
        (teacher, embed_dim) when ``only_teacher`` is True, otherwise
        (student, teacher, embed_dim). Student additionally gets the
        drop-path settings from ``args``.
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        # arch names the factory in vision_transformer (vit_small/base/large/giant2)
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    return student, teacher, embed_dim


def build_model_from_cfg(cfg, only_teacher=False):
    """Thin wrapper: build from a structured config (student + crop size sections)."""
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.nn.init import trunc_normal_

from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from ...layers.attention import FlashAttention


# logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply `fn(module=..., name=...)` over a module tree (depth-first by default)."""
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    # Runs its blocks sequentially; used to group blocks into FSDP-wrappable units.
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone (vendored; attention swapped to FlashAttention)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # pos_embed covers cls token + patches; register tokens get no positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            # logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            # logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            # logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        # NOTE(review): attn_class=FlashAttention here overrides any attn_class
        # baked into a block_fn partial (e.g. MemEffAttention in vit_small below),
        # since call-time keywords take precedence over functools.partial keywords
        # — confirm this override is intended.
        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                attn_class=FlashAttention
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Truncated-normal pos-embed, small-normal cls/register tokens, timm Linear init elsewhere."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resize the patch positional embeddings to the current (w, h) grid.

        Assumes the pretrained pos-embed grid is square (N == M*M).
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify, optionally replace masked patches with the mask token,
        prepend cls (and register) tokens, and add positional embeddings."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        # Register tokens are inserted AFTER positional embedding, between cls and patches.
        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run a list of (images, masks) jointly through the nested-tensor blocks.

        Returns one dict per input with normalized cls/register/patch tokens.
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            # gradient checkpointing in training to trade compute for memory
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Single-tensor forward; dispatches to the list variant for list input."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect raw outputs of selected blocks (flat ModuleList layout)."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect raw outputs of selected blocks (BlockChunk layout; leading
        nn.Identity padding keeps global block indices consistent)."""
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens of the selected blocks, optionally normalized,
        reshaped to (B, C, H', W'), and/or paired with the class token."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # drop cls + register tokens, keep only patch tokens
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Full feature dict when is_training, else the (identity) head on the cls token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S factory: embed_dim 384, depth 12, 6 heads."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
**kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/prismatic/models/pi3/models/dinov2/utils/__init__.py b/prismatic/models/pi3/models/dinov2/utils/__init__.py new file mode 100644 index 0000000..b88da6b --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/prismatic/models/pi3/models/dinov2/utils/cluster.py b/prismatic/models/pi3/models/dinov2/utils/cluster.py new file mode 100644 index 0000000..3df87dc --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 
0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/prismatic/models/pi3/models/dinov2/utils/config.py b/prismatic/models/pi3/models/dinov2/utils/config.py new file mode 100644 index 0000000..c9de578 --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. 
+ """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/prismatic/models/pi3/models/dinov2/utils/dtype.py b/prismatic/models/pi3/models/dinov2/utils/dtype.py new file mode 100644 index 0000000..80f4cd7 --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/prismatic/models/pi3/models/dinov2/utils/param_groups.py b/prismatic/models/pi3/models/dinov2/utils/param_groups.py new file mode 100644 index 0000000..9a5d2ff --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." 
not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git 
a/prismatic/models/pi3/models/dinov2/utils/utils.py b/prismatic/models/pi3/models/dinov2/utils/utils.py new file mode 100644 index 0000000..e8842e4 --- /dev/null +++ b/prismatic/models/pi3/models/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +# logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. 
+ """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/prismatic/models/pi3/models/layers/attention.py b/prismatic/models/pi3/models/layers/attention.py new file mode 100644 index 0000000..ca7702b --- /dev/null +++ b/prismatic/models/pi3/models/layers/attention.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch + +from torch.nn.functional import scaled_dot_product_attention +from torch.backends.cuda import SDPBackend + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, 
attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlashAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + + if q.dtype == torch.bfloat16: + with torch.backends.cuda.sdp_kernel(enable_flash=True, + enable_mem_efficient=False, + enable_math=False): + x = scaled_dot_product_attention(q, k, v) + else: + with torch.backends.cuda.sdp_kernel(enable_flash=False, + enable_mem_efficient=True, + enable_math=True): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +""" +Following is written by GPT-4o +""" +class CrossAttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + # Separate projection layers for query, key, and value + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if 
qk_norm else nn.Identity() + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + B, N, C = query.shape + _, M, _ = key.shape + + # Project query, key, and value + q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + # Scale query + q = q * self.scale + + # Compute attention scores + attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M) + if attn_bias is not None: + attn = attn + attn_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # Compute attention output + x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffCrossAttentionRope(CrossAttentionRope): + def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor: + """ + Args: + query: Tensor of shape (B, N, C), input query + key: Tensor of shape (B, M, C), input key + value: Tensor of shape (B, M, C), input value + attn_bias: Optional tensor for attention bias + Returns: + Tensor of shape (B, N, C), output of cross-attention + """ + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + 
raise AssertionError("xFormers is required for using nested tensors") + return super().forward(query, key, value, attn_bias) + + B, N, C = query.shape + _, M, _ = key.shape + + # Project query, key, and value + q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads) + k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads) + v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + # Compute memory-efficient attention + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape(B, N, C) + + # Final projection + x = self.proj(x) + x = self.proj_drop(x) + return x + +class AttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + + self.rope = rope + + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = 
q * self.scale + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix + # global_valid_id = torch.where(score_matrix > 0) + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FlashAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + if q.dtype == torch.bfloat16: + with torch.backends.cuda.sdp_kernel(enable_flash=True, + enable_mem_efficient=False, + 
enable_math=False): + x = scaled_dot_product_attention(q, k, v) + else: + with torch.backends.cuda.sdp_kernel(enable_flash=False, + enable_mem_efficient=True, + enable_math=True): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +def get_attn_score(blk_class, x, frame_num, token_length, xpos=None): + x = blk_class.norm1(x) + + B, N, C = x.shape + qkv = blk_class.attn.qkv(x).reshape(B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype) + + if blk_class.attn.rope is not None: + q = blk_class.attn.rope(q, xpos) + k = blk_class.attn.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + score = (q.permute(0, 2, 1, 3) * blk_class.attn.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(B, frame_num, token_length, frame_num, token_length).mean(dim=[2, 4]).sum(-1) + + return score \ No newline at end of file diff --git a/prismatic/models/pi3/models/layers/block.py b/prismatic/models/pi3/models/layers/block.py new file mode 100644 index 0000000..c2c1f95 --- /dev/null +++ b/prismatic/models/pi3/models/layers/block.py @@ -0,0 +1,406 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope +from ..dinov2.layers.drop_path import DropPath +from ..dinov2.layers.layer_scale import LayerScale +from ..dinov2.layers.mlp import Mlp + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim 
= int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, 
sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = 
class NestedTensorBlock(Block):
    """Block variant that can run on a *list* of tensors with differing
    sequence lengths by packing them into one sequence with a
    block-diagonal attention bias (requires xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """Run the block on a list of tensors nested into a single sequence.

        Requires ``self.attn`` to be a ``MemEffAttention``, because the
        packed path passes an xFormers ``attn_bias`` into the attention.
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # BUGFIX: the guard previously tested isinstance(self.ls1, ...)
                # while reading self.ls2.gamma. ls1/ls2 are constructed together
                # (both LayerScale or both Identity) so behavior was unchanged,
                # but the check now matches the value it actually guards.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Pack all sequences into one, attend with a block-diagonal mask,
            # then split back into per-input tensors.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch: a plain Tensor takes the regular Block forward; a list of
        Tensors takes the nested (packed) path, which needs xFormers."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
class CrossBlockRope(nn.Module):
    """Transformer block with rotary-position-aware self-attention followed by
    cross-attention onto a context sequence, then an MLP.

    Per forward pass:
        x <- x + ls1  * SelfAttn(norm1(x))
        x <- x + ls_y * CrossAttn(norm2(x), norm_y(y))
        x <- x + ls2  * MLP(norm3(x))
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        init_values=None,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()

        # --- self-attention branch ---
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm,
        )

        # --- cross-attention branch (y is the context sequence) ---
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.ls_y = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm_y = norm_layer(dim)
        self.cross_attn = cross_attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm,
        )

        # --- feed-forward branch ---
        self.norm3 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            bias=ffn_bias,
        )

    def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
        """Apply self-attention, cross-attention onto ``y``, then the MLP,
        each as a layer-scaled residual. ``xpos``/``ypos`` are the rotary
        position grids for the query and context sequences."""
        x = x + self.ls1(self.attn(self.norm1(x), xpos=xpos))
        context = self.norm_y(y)
        x = x + self.ls_y(self.cross_attn(self.norm2(x), context, context, qpos=xpos, kpos=ypos))
        x = x + self.ls2(self.mlp(self.norm3(x)))
        return x
class CameraHead(nn.Module):
    """Regress a per-view 4x4 camera pose (SO(3) rotation + translation)
    from patch-token features."""

    def __init__(self, dim=512):
        super().__init__()
        output_dim = dim
        # NOTE: each loop iteration already constructs a fresh ResConvBlock,
        # so the previous deepcopy() wrapper was redundant and is removed
        # (identical parameters, identical RNG consumption).
        self.res_conv = nn.ModuleList(
            [ResConvBlock(output_dim, output_dim) for _ in range(2)]
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.more_mlps = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        self.fc_t = nn.Linear(output_dim, 3)    # translation head
        self.fc_rot = nn.Linear(output_dim, 9)  # 9D rotation head (SVD-orthogonalized)

    def forward(self, feat, patch_h, patch_w):
        """feat: [B*N, h*w, C] patch tokens; returns [B*N, 4, 4] poses."""
        BN, hw, c = feat.shape

        for i in range(2):
            feat = self.res_conv[i](feat)

        # Pool over the spatial grid: tokens -> [BN, C, h, w] -> [BN, C, 1, 1].
        feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous())
        feat = feat.view(feat.size(0), -1)

        feat = self.more_mlps(feat)  # [BN, C]
        # Regress the pose in full fp32: the SVD inside is numerically
        # fragile under autocast half precision.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            out_t = self.fc_t(feat.float())    # [BN, 3]
            out_r = self.fc_rot(feat.float())  # [BN, 9]
            pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)

        return pose

    def convert_pose_to_4x4(self, B, out_r, out_t, device):
        """Assemble [B, 4, 4] homogeneous poses from 9D rotations and translations."""
        out_r = self.svd_orthogonalize(out_r)  # [B, 3, 3]
        pose = torch.zeros((B, 4, 4), device=device)
        pose[:, :3, :3] = out_r
        pose[:, :3, 3] = out_t
        pose[:, 3, 3] = 1.
        return pose

    def svd_orthogonalize(self, m):
        """Convert a 9D representation to SO(3) using SVD orthogonalization.

        Args:
            m: [BATCH, 3, 3] (or [BATCH, 9], reshaped) matrices.

        Returns:
            [BATCH, 3, 3] SO(3) rotation matrices (det == +1).
        """
        if m.dim() < 3:
            m = m.reshape((-1, 3, 3))
        m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2)
        # torch.svd is deprecated; torch.linalg.svd returns V^H, so transpose
        # it back to recover the V that torch.svd used to return.
        u, s, vh = torch.linalg.svd(m_transpose)
        v = vh.transpose(-2, -1)
        det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
        # Flip the last column of V when det < 0 so the result is a proper
        # rotation rather than a reflection.
        r = torch.matmul(
            torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
            u.transpose(-2, -1)
        )
        return r
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine positional embedding for a 1-D list of positions.

    Args:
        embed_dim: output dimension per position (must be even).
        pos: array of positions of any shape; flattened to (M,).

    Returns:
        (M, embed_dim) array laid out as [sin(p*w), cos(p*w)] over the
        frequency ladder w_i = 10000^(-2i/embed_dim).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder, as in the original Transformer paper.
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=float) / half)  # (D/2,)
    # Outer product positions x frequencies via broadcasting.
    angles = pos.reshape(-1)[:, None] * freqs[None, :]  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
%dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +#---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +#---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D,seq_len,device,dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D,seq_len,device,dtype] = (cos,sin) + return self.cache[D,seq_len,device,dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim==2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + 
class PositionGetter(object):
    """Return (y, x) grid coordinates for each patch of an h x w grid,
    expanded to batch size b. Results are cached per (h, w, device)."""

    def __init__(self):
        # Cache of flattened coordinate grids, keyed by (h, w, device).
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        # BUGFIX: the cache was previously keyed on (h, w) only, so a later
        # call with the same grid size but a different device returned a
        # tensor living on the wrong device. The key now includes the device.
        key = (h, w, device)
        if key not in self.cache_positions:
            y = torch.arange(h, device=device)
            x = torch.arange(w, device=device)
            # cartesian_prod yields (h*w, 2) pairs in row-major order
            # (y varies slowest), i.e. [(0,0), (0,1), ..., (1,0), ...].
            self.cache_positions[key] = torch.cartesian_prod(y, x)
        # Clone so callers can mutate their copy without poisoning the cache.
        return self.cache_positions[key].view(1, h * w, 2).expand(b, -1, 2).clone()
class LinearPts3d(nn.Module):
    """Linear head for dust3r: each patch token is projected to a
    patch_size x patch_size grid of per-pixel outputs (e.g. 3D points,
    optionally with confidence)."""

    def __init__(self, patch_size, dec_embed_dim, output_dim=3,):
        super().__init__()
        self.patch_size = patch_size
        # One linear layer maps a token to output_dim values per pixel of its patch.
        self.proj = nn.Linear(dec_embed_dim, output_dim * patch_size ** 2)

    def forward(self, decout, img_shape):
        """decout: list of decoder feature maps; only the last is used.
        img_shape: (H, W) in pixels. Returns a (B, H, W, output_dim) map."""
        height, width = img_shape
        tokens = decout[-1]
        batch = tokens.shape[0]

        grid_h = height // self.patch_size
        grid_w = width // self.patch_size

        projected = self.proj(tokens)  # (B, S, output_dim * p^2)
        # Put tokens back onto the patch grid, then expand patches to pixels.
        grid = projected.transpose(-1, -2).view(batch, -1, grid_h, grid_w)
        dense = F.pixel_shuffle(grid, self.patch_size)  # (B, output_dim, H, W)
        return dense.permute(0, 2, 3, 1)  # channel-last
class Pi3(nn.Module, PyTorchModelHubMixin):
    """Pi3 multi-view 3D reconstruction model.

    A DINOv2 ViT-L/14 encoder (with register tokens) feeds a deep
    RoPE2D-based decoder that alternates frame-wise and global attention
    across views. Three transformer heads then predict per-pixel local 3D
    points, per-pixel confidences, and per-view 4x4 camera poses; local
    points are unprojected into a common world frame using the poses.
    """

    def __init__(
        self,
        pos_type='rope100',
        decoder_size='large',
    ):
        super().__init__()

        # ----------------------
        # Encoder
        # ----------------------
        self.encoder = dinov2_vitl14_reg(pretrained=False)
        self.patch_size = 14
        # The mask token is only needed for masked-image pretraining; drop
        # it so it does not linger in the state dict.
        del self.encoder.mask_token

        # ----------------------
        # Positonal Encoding
        # ----------------------
        self.pos_type = pos_type if pos_type is not None else 'none'
        self.rope=None
        if self.pos_type.startswith('rope'): # eg rope100
            if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            # Frequency is encoded in the string, e.g. 'rope100' -> 100.0.
            freq = float(self.pos_type[len('rope'):])
            self.rope = RoPE2D(freq=freq)
            self.position_getter = PositionGetter()
        else:
            # Only the RoPE variant is implemented.
            raise NotImplementedError


        # ----------------------
        # Decoder
        # ----------------------
        # Informational only: the encoder embedding width (1024 for ViT-L).
        enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features # 1024
        if decoder_size == 'small':
            dec_embed_dim = 384
            dec_num_heads = 6
            mlp_ratio = 4
            dec_depth = 24
        elif decoder_size == 'base':
            dec_embed_dim = 768
            dec_num_heads = 12
            mlp_ratio = 4
            dec_depth = 24
        elif decoder_size == 'large':
            dec_embed_dim = 1024
            dec_num_heads = 16
            mlp_ratio = 4
            dec_depth = 36
        else:
            raise NotImplementedError
        self.decoder = nn.ModuleList([
            BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=0.01,
                qk_norm=True,
                attn_class=FlashAttentionRope,
                rope=self.rope
            ) for _ in range(dec_depth)])
        self.dec_embed_dim = dec_embed_dim

        # ----------------------
        # Register_token
        # ----------------------
        # Learned special tokens prepended to every view's patch sequence.
        num_register_tokens = 5
        # Index where real patch tokens start after the special tokens.
        self.patch_start_idx = num_register_tokens
        self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
        nn.init.normal_(self.register_token, std=1e-6)

        # ----------------------
        # Local Points Decoder
        # ----------------------
        # in_dim is 2*dec_embed_dim because decode() concatenates the
        # outputs of the last two decoder layers.
        self.point_decoder = TransformerDecoder(
            in_dim=2*self.dec_embed_dim,
            dec_embed_dim=1024,
            dec_num_heads=16,
            out_dim=1024,
            rope=self.rope,
        )
        self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)

        # ----------------------
        # Conf Decoder
        # ----------------------
        # Same architecture as the point decoder, independently trained weights.
        self.conf_decoder = deepcopy(self.point_decoder)
        self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)

        # ----------------------
        # Camera Pose Decoder
        # ----------------------
        self.camera_decoder = TransformerDecoder(
            in_dim=2*self.dec_embed_dim,
            dec_embed_dim=1024,
            dec_num_heads=16, # 8
            out_dim=512,
            rope=self.rope,
            use_checkpoint=False
        )
        self.camera_head = CameraHead(dim=512)

        # For ImageNet Normalize
        image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)


    def decode(self, hidden, N, H, W):
        """Run the alternating frame-wise / global decoder over encoder tokens.

        Args:
            hidden: (B*N, h*w, C) encoder patch tokens for all views.
            N: number of views per sample.
            H, W: input image height and width in pixels.

        Returns:
            (features, pos): features is (B*N, hw, 2*C) — the concatenation
            of the last two decoder layers' outputs (register tokens
            included) — and pos is the matching (B*N, hw, 2) rotary position
            grid (zeros for the special tokens).
        """
        BN, hw, _ = hidden.shape
        B = BN // N

        final_output = []

        hidden = hidden.reshape(B*N, hw, -1)

        register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:])

        # Concatenate special tokens with patch tokens
        hidden = torch.cat([register_token, hidden], dim=1)
        hw = hidden.shape[1]

        if self.pos_type.startswith('rope'):
            pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device)

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            # (patch positions are shifted by +1 so position 0 is reserved
            # exclusively for the special tokens).
            pos = pos + 1
            pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        for i in range(len(self.decoder)):
            blk = self.decoder[i]

            # Alternate attention scope: even layers attend within a single
            # view, odd layers attend across all N views jointly.
            if i % 2 == 0:
                pos = pos.reshape(B*N, hw, -1)
                hidden = hidden.reshape(B*N, hw, -1)
            else:
                pos = pos.reshape(B, N*hw, -1)
                hidden = hidden.reshape(B, N*hw, -1)

            hidden = blk(hidden, xpos=pos)

            # Keep the outputs of the last two layers for the heads.
            if i+1 in [len(self.decoder)-1, len(self.decoder)]:
                final_output.append(hidden.reshape(B*N, hw, -1))

        return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1)

    def forward(self, imgs):
        """Predict world points, local points, confidences, and camera poses.

        Args:
            imgs: (B, N, 3, H, W) RGB images in [0, 1]; H and W are assumed
                to be multiples of the 14-pixel patch size (TODO confirm
                upstream preprocessing guarantees this).

        Returns:
            dict with 'points' (B, N, H, W, 3) world-frame points,
            'local_points' (B, N, H, W, 3) camera-frame points,
            'conf' (B, N, H, W, 1) confidences, and
            'camera_poses' (B, N, 4, 4) homogeneous poses.
        """
        imgs = (imgs - self.image_mean) / self.image_std

        # NOTE: '_' captures the channel count (3) and is reused in the
        # reshape below.
        B, N, _, H, W = imgs.shape
        patch_h, patch_w = H // 14, W // 14

        # encode by dinov2
        imgs = imgs.reshape(B*N, _, H, W)
        hidden = self.encoder(imgs, is_training=True)

        if isinstance(hidden, dict):
            hidden = hidden["x_norm_patchtokens"]

        hidden, pos = self.decode(hidden, N, H, W)

        point_hidden = self.point_decoder(hidden, xpos=pos)
        conf_hidden = self.conf_decoder(hidden, xpos=pos)
        camera_hidden = self.camera_decoder(hidden, xpos=pos)

        # Heads run in full fp32 regardless of any surrounding autocast.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            # local points
            point_hidden = point_hidden.float()
            ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
            # Depth is predicted in log space; xy are ray directions scaled by depth.
            xy, z = ret.split([2, 1], dim=-1)
            z = torch.exp(z)
            local_points = torch.cat([xy * z, z], dim=-1)

            # confidence
            conf_hidden = conf_hidden.float()
            conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)

            # camera
            camera_hidden = camera_hidden.float()
            camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)

            # unproject local points using camera poses
            points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3]

        return dict(
            points=points,
            local_points=local_points,
            conf=conf,
            camera_poses=camera_poses,
        )
Must be a directory or a .mp4 file: {path}") + + if not sources: + print("No images found or loaded.") + return torch.empty(0) + + print(f"Found {len(sources)} images/frames. Processing...") + + # --- 2. Determine a uniform target size for all images based on the first image --- + # This is necessary to ensure all tensors have the same dimensions for stacking. + first_img = sources[0] + W_orig, H_orig = first_img.size + scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1 + W_target, H_target = W_orig * scale, H_orig * scale + k, m = round(W_target / 14), round(H_target / 14) + while (k * 14) * (m * 14) > PIXEL_LIMIT: + if k / m > W_target / H_target: k -= 1 + else: m -= 1 + TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14 + print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})") + + # --- 3. Resize images and convert them to tensors in the [0, 1] range --- + tensor_list = [] + # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1] + to_tensor_transform = transforms.ToTensor() + + for img_pil in sources: + try: + # Resize to the uniform target size + resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS) + # Convert to tensor + img_tensor = to_tensor_transform(resized_img) + tensor_list.append(img_tensor) + except Exception as e: + print(f"Error processing an image: {e}") + + if not tensor_list: + print("No images were successfully processed.") + return torch.empty(0) + + # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor --- + return torch.stack(tensor_list, dim=0) + + +def tensor_to_pil(tensor): + """ + Converts a PyTorch tensor to a PIL image. Automatically moves the channel dimension + (if it has size 3) to the last axis before converting. + + Args: + tensor (torch.Tensor): Input tensor. Expected shape can be [C, H, W], [H, W, C], or [H, W]. + + Returns: + PIL.Image: The converted PIL image. 
+ """ + if torch.is_tensor(tensor): + array = tensor.detach().cpu().numpy() + else: + array = tensor + + return array_to_pil(array) + + +def array_to_pil(array): + """ + Converts a NumPy array to a PIL image. Automatically: + - Squeezes dimensions of size 1. + - Moves the channel dimension (if it has size 3) to the last axis. + + Args: + array (np.ndarray): Input array. Expected shape can be [C, H, W], [H, W, C], or [H, W]. + + Returns: + PIL.Image: The converted PIL image. + """ + # Remove singleton dimensions + array = np.squeeze(array) + + # Ensure the array has the channel dimension as the last axis + if array.ndim == 3 and array.shape[0] == 3: # If the channel is the first axis + array = np.transpose(array, (1, 2, 0)) # Move channel to the last axis + + # Handle single-channel grayscale images + if array.ndim == 2: # [H, W] + return Image.fromarray((array * 255).astype(np.uint8), mode="L") + elif array.ndim == 3 and array.shape[2] == 3: # [H, W, C] with 3 channels + return Image.fromarray((array * 255).astype(np.uint8), mode="RGB") + else: + raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}") + + +def rotate_target_dim_to_last_axis(x, target_dim=3): + shape = x.shape + axis_to_move = -1 + # Iterate backwards to find the first occurrence from the end + # (which corresponds to the last dimension of size 3 in the original order). + for i in range(len(shape) - 1, -1, -1): + if shape[i] == target_dim: + axis_to_move = i + break + + # 2. If the axis is found and it's not already in the last position, move it. + if axis_to_move != -1 and axis_to_move != len(shape) - 1: + # Create the new dimension order. + dims_order = list(range(len(shape))) + dims_order.pop(axis_to_move) + dims_order.append(axis_to_move) + + # Use permute to reorder the dimensions. 
+ ret = x.transpose(*dims_order) + else: + ret = x + + return ret + + +def write_ply( + xyz, + rgb=None, + path='output.ply', +) -> None: + if torch.is_tensor(xyz): + xyz = xyz.detach().cpu().numpy() + + if torch.is_tensor(rgb): + rgb = rgb.detach().cpu().numpy() + + if rgb is not None and rgb.max() > 1: + rgb = rgb / 255. + + xyz = rotate_target_dim_to_last_axis(xyz, 3) + xyz = xyz.reshape(-1, 3) + + if rgb is not None: + rgb = rotate_target_dim_to_last_axis(rgb, 3) + rgb = rgb.reshape(-1, 3) + + if rgb is None: + min_coord = np.min(xyz, axis=0) + max_coord = np.max(xyz, axis=0) + normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8) + + hue = 0.7 * normalized_coord[:,0] + 0.2 * normalized_coord[:,1] + 0.1 * normalized_coord[:,2] + hsv = np.stack([hue, 0.9*np.ones_like(hue), 0.8*np.ones_like(hue)], axis=1) + + c = hsv[:,2:] * hsv[:,1:2] + x = c * (1 - np.abs( (hsv[:,0:1]*6) % 2 - 1 )) + m = hsv[:,2:] - c + + rgb = np.zeros_like(hsv) + cond = (0 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 1) + rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])]) + cond = (1 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 2) + rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])]) + cond = (2 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 3) + rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]]) + cond = (3 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 4) + rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]]) + cond = (4 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 5) + rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]]) + cond = (5 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 6) + rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]]) + rgb = (rgb + m) + + dtype = [ + ("x", "f4"), + ("y", "f4"), + ("z", "f4"), + ("nx", "f4"), + ("ny", "f4"), + ("nz", "f4"), + ("red", "u1"), + ("green", "u1"), + ("blue", "u1"), + ] + normals = np.zeros_like(xyz) + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb * 255), axis=1) 
+ elements[:] = list(map(tuple, attributes)) + vertex_element = PlyElement.describe(elements, "vertex") + ply_data = PlyData([vertex_element]) + ply_data.write(path) \ No newline at end of file diff --git a/prismatic/models/pi3/utils/debug.py b/prismatic/models/pi3/utils/debug.py new file mode 100644 index 0000000..f3da8f3 --- /dev/null +++ b/prismatic/models/pi3/utils/debug.py @@ -0,0 +1,63 @@ +import os +import json +import debugpy +import socket +import random + +def update_vscode_launch_file(host: str, port: int): + """Update the .vscode/launch.json file with the new host and port.""" + launch_file_path = ".vscode/launch.json" + # Desired configuration + new_config = { + "version": "0.2.0", + "configurations": [ + { + "name": "bash_debug", + "type": "debugpy", + "request": "attach", + "connect": { + "host": host, + "port": port + }, + "justMyCode": False + }, + ] + } + + # Ensure the .vscode directory exists + if not os.path.exists(".vscode"): + os.makedirs(".vscode") + + # Write the updated configuration to launch.json + with open(launch_file_path, "w") as f: + json.dump(new_config, f, indent=4) + print(f"Updated {launch_file_path} with host: {host} and port: {port}") + +def is_port_in_use(host, port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex((host, port)) == 0 + +def setup_debug(is_main_process=True, max_retries=10, port_range=(10000, 20000)): + if is_main_process: + host = os.environ['SLURM_NODELIST'].split(',')[0] + + for _ in range(max_retries): + port = random.randint(*port_range) + try: + if is_port_in_use(host, port): + print(f"Port {port} is already in use, trying another...") + continue + + # 更新 launch.json + update_vscode_launch_file(host, port) + + print("master_addr = ", host) + debugpy.listen((host, port)) + print(f"Waiting for debugger attach at port {port}...", flush=True) + debugpy.wait_for_client() + print("Debugger attached", flush=True) + return + except Exception as e: + print(f"Failed to bind to 
port {port}: {e}") + + raise RuntimeError("Could not find a free port for debugpy after several attempts.") \ No newline at end of file diff --git a/prismatic/models/pi3/utils/geometry.py b/prismatic/models/pi3/utils/geometry.py new file mode 100644 index 0000000..515a36f --- /dev/null +++ b/prismatic/models/pi3/utils/geometry.py @@ -0,0 +1,375 @@ +import numpy as np +import torch +import torch.nn.functional as F + +def se3_inverse(T): + """ + Computes the inverse of a batch of SE(3) matrices. + T: Tensor of shape (B, 4, 4) + """ + if len(T.shape) == 2: + T = T[None] + unseq_flag = True + else: + unseq_flag = False + + if torch.is_tensor(T): + R = T[:, :3, :3] + t = T[:, :3, 3].unsqueeze(-1) + R_inv = R.transpose(-2, -1) + t_inv = -torch.matmul(R_inv, t) + T_inv = torch.cat([ + torch.cat([R_inv, t_inv], dim=-1), + torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(T.shape[0], 1, 1) + ], dim=1) + else: + R = T[:, :3, :3] + t = T[:, :3, 3, np.newaxis] + + R_inv = np.swapaxes(R, -2, -1) + t_inv = -R_inv @ t + + bottom_row = np.zeros((T.shape[0], 1, 4), dtype=T.dtype) + bottom_row[:, :, 3] = 1 + + top_part = np.concatenate([R_inv, t_inv], axis=-1) + T_inv = np.concatenate([top_part, bottom_row], axis=1) + + if unseq_flag: + T_inv = T_inv[0] + return T_inv + +def get_pixel(H, W): + # get 2D pixels (u, v) for image_a in cam_a pixel space + u_a, v_a = np.meshgrid(np.arange(W), np.arange(H)) + # u_a = np.flip(u_a, axis=1) + # v_a = np.flip(v_a, axis=0) + pixels_a = np.stack([ + u_a.flatten() + 0.5, + v_a.flatten() + 0.5, + np.ones_like(u_a.flatten()) + ], axis=0) + + return pixels_a + +def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.""" + X_cam, valid_mask = 
depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + if z_far > 0: + valid_mask = valid_mask & (depthmap < z_far) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + + return X_world, valid_mask + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + # assert camera_intrinsics[0, 1] == 0.0 + # assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + # Mask for valid coordinates + valid_mask = (depthmap > 0.0) + # Invalid any depth > 80m + valid_mask = valid_mask + return X_cam, valid_mask + +def homogenize_points( + points, +): + """Convert batched points (xyz) to (xyz1).""" + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): + + if H is None: + B,H,W = 
depth1.shape + else: + B = depth1.shape[0] + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=depth1.device + ) + for n in (B, H, W) + ], + indexing = 'ij' + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + depth_interpolation_mode = depth_interpolation_mode, + relative_depth_error_threshold = relative_depth_error_threshold, + ) + prob = mask.float().reshape(B, H, W) + x2 = x2.reshape(B, H, W, 2) + return x2, prob + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + if depth_interpolation_mode == "combined": + # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation + if smooth_mask: + raise NotImplementedError("Combined bilinear and NN warp not implemented") + valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "bilinear", + 
relative_depth_error_threshold = relative_depth_error_threshold) + valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "nearest-exact", + relative_depth_error_threshold = relative_depth_error_threshold) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + warp = warp_bilinear.clone() + warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + valid = valid_bilinear | valid_nearest + return valid, warp + + + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + # nonzero_mask = kpts0_depth != 0 + # Sample depth, get calculable_mask on depth > 0 + nonzero_mask = kpts0_depth > 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = 
F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + )[:, 0, :, 0] + + relative_depth_error = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() + if not smooth_mask: + consistent_mask = relative_depth_error < relative_depth_error_threshold + else: + consistent_mask = (-relative_depth_error/smooth_mask).exp() + valid_mask = nonzero_mask * covisible_mask * consistent_mask + if return_relative_depth_error: + return relative_depth_error, w_kpts0 + else: + return valid_mask, w_kpts0 + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and + Trf.ndim == 3 and pts.ndim == 4): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] + else: + raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}') + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 
2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """ Invert a torch or numpy matrix + """ + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f'bad matrix type = {type(mat)}') + +def opencv_camera_to_plucker(poses, K, H, W): + device = poses.device + B = poses.shape[0] + + pixel = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device).T.reshape(H, W, 3)[None].repeat(B, 1, 1, 1) # (3, H, W) + pixel = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel) + ray_directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], pixel) + + ray_origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1) + + ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True) + plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1) + plucker_ray = torch.cat([ray_directions, plucker_normal], dim=-1) + + return plucker_ray + + +def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth. 
+ + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge \ No newline at end of file diff --git a/prismatic/models/pi3_loader.py b/prismatic/models/pi3_loader.py new file mode 100644 index 0000000..286577a --- /dev/null +++ b/prismatic/models/pi3_loader.py @@ -0,0 +1,257 @@ +""" +pi3_loader.py + +Implementations of pi3_loader, loading pi3 model which predicts pointclouds and camera extrinsics from images. 
+""" +from typing import Tuple, List, Optional, Dict, Union, Type +from pathlib import Path +from termcolor import cprint + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from prismatic.models.pi3.models.pi3 import Pi3 +from prismatic.overwatch import initialize_overwatch + +overwatch = initialize_overwatch(__name__) + +def load_pc_model(pi3_path: Union[str, Path]) -> Pi3: + overwatch.info(f"Loading PC model from {pi3_path}") + if pi3_path is not None: + pc_model = Pi3.from_pretrained(Path(pi3_path) if isinstance(pi3_path, str) else pi3_path) + overwatch.info(f"PC model Loaded Successfully from loacal dir: {pi3_path}") + else: + raise ValueError("Please provide a valid path or repo id to a PC model") + + return pc_model + +# Pointcloud Encoder +def meanpool(x, dim=-1, keepdim=False): + out = x.mean(dim=dim, keepdim=keepdim) + return out + +def maxpool(x, dim=-1, keepdim=False): + out = x.max(dim=dim, keepdim=keepdim).values + return out + +class MultiStagePointNetEncoder(nn.Module): + def __init__(self, h_dim=128, out_channels=128, num_layers=4, **kwargs): + super().__init__() + + self.h_dim = h_dim + self.out_channels = out_channels + self.num_layers = num_layers + + self.act = nn.LeakyReLU(negative_slope=0.0, inplace=False) + + self.conv_in = nn.Conv1d(3, h_dim, kernel_size=1) + self.layers, self.global_layers = nn.ModuleList(), nn.ModuleList() + for i in range(self.num_layers): + self.layers.append(nn.Conv1d(h_dim, h_dim, kernel_size=1)) + self.global_layers.append(nn.Conv1d(h_dim * 2, h_dim, kernel_size=1)) + self.conv_out = nn.Conv1d(h_dim * self.num_layers, out_channels, kernel_size=1) + + def forward(self, x): + x = x.transpose(1, 2) # [B, N, 3] --> [B, 3, N] + y = self.act(self.conv_in(x)) + feat_list = [] + for i in range(self.num_layers): + y = self.act(self.layers[i](y)) + y_global = y.max(-1, keepdim=True).values + y = torch.cat([y, y_global.expand_as(y)], dim=1) + y = self.act(self.global_layers[i](y)) + 
feat_list.append(y) + x = torch.cat(feat_list, dim=1) + x = self.conv_out(x) + + x_global = x.max(-1).values + + return x_global + +def shuffle_point_numpy(point_cloud): + B, N, C = point_cloud.shape + indices = np.random.permutation(N) + return point_cloud[:, indices] + +def pad_point_numpy(point_cloud, num_points): + B, N, C = point_cloud.shape + if num_points > N: + num_pad = num_points - N + pad_points = np.zeros((B, num_pad, C)) + point_cloud = np.concatenate([point_cloud, pad_points], axis=1) + point_cloud = shuffle_point_numpy(point_cloud) + return point_cloud + +def uniform_sampling_numpy(point_cloud, num_points): + B, N, C = point_cloud.shape + # padd if num_points > N + if num_points > N: + return pad_point_numpy(point_cloud, num_points) + + # random sampling + indices = np.random.permutation(N)[:num_points] + sampled_points = point_cloud[:, indices] + return sampled_points + +def shuffle_point_torch(point_cloud): + B, N, C = point_cloud.shape + indices = torch.randperm(N) + return point_cloud[:, indices] + +def pad_point_torch(point_cloud, num_points): + B, N, C = point_cloud.shape + device = point_cloud.device + if num_points > N: + num_pad = num_points - N + pad_points = torch.zeros(B, num_pad, C).to(device) + point_cloud = torch.cat([point_cloud, pad_points], dim=1) + point_cloud = shuffle_point_torch(point_cloud) + return point_cloud + +def uniform_sampling_torch(point_cloud, num_points): + B, N, C = point_cloud.shape + device = point_cloud.device + # padd if num_points > N + if num_points == N: + return point_cloud + if num_points > N: + return pad_point_torch(point_cloud, num_points) + + # random sampling + indices = torch.randperm(N)[:num_points] + sampled_points = point_cloud[:, indices] + return sampled_points + +def create_mlp( + input_dim: int, + output_dim: int, + net_arch: List[int], + activation_fn: Type[nn.Module] = nn.ReLU, + squash_output: bool = False, +) -> List[nn.Module]: + """ + Create a multi layer perceptron (MLP), which is + a 
collection of fully-connected layers each followed by an activation function. + + :param input_dim: Dimension of the input vector + :param output_dim: + :param net_arch: Architecture of the neural net + It represents the number of units per layer. + The length of this list is the number of layers. + :param activation_fn: The activation function + to use after each layer. + :param squash_output: Whether to squash the output using a Tanh + activation function + :return: + """ + + if len(net_arch) > 0: + modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()] + else: + modules = [] + + for idx in range(len(net_arch) - 1): + modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) + modules.append(activation_fn()) + + if output_dim > 0: + last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim + modules.append(nn.Linear(last_layer_dim, output_dim)) + if squash_output: + modules.append(nn.Tanh()) + return modules + + +class iDP3Encoder(nn.Module): + """ + 修改后的 iDP3Encoder,只处理点云数据,删除了所有 state 相关的部分 + """ + def __init__(self, + observation_space: Dict, + pointcloud_encoder_cfg=None, + use_pc_color=False, + pointnet_type='multi_stage_pointnet', + point_downsample=True, + ): + super().__init__() + self.point_cloud_key = 'point_cloud' + self.n_output_channels = pointcloud_encoder_cfg.out_channels + + self.point_cloud_shape = observation_space[self.point_cloud_key] + self.num_points = pointcloud_encoder_cfg.num_points # 4096 + + print(f"[iDP3Encoder] point cloud shape: {self.point_cloud_shape}") + + self.use_pc_color = use_pc_color + self.pointnet_type = pointnet_type + + self.downsample = point_downsample + if self.downsample: + self.point_preprocess = uniform_sampling_torch + else: + self.point_preprocess = nn.Identity() + + if pointnet_type == "multi_stage_pointnet": + self.extractor = MultiStagePointNetEncoder( + out_channels=pointcloud_encoder_cfg.out_channels + ) + else: + raise NotImplementedError(f"pointnet_type: {pointnet_type}") + + 
print(f"[iDP3Encoder] output dim: {self.n_output_channels}") + + def forward(self, observations: Dict) -> torch.Tensor: + points = observations[self.point_cloud_key] + assert len(points.shape) == 3, f"point cloud shape: {points.shape}, length should be 3" + + # 下采样点云 + if self.downsample: + points = self.point_preprocess(points, self.num_points) + + # 提取点云特征 + pn_feat = self.extractor(points) # B * out_channels + + return pn_feat + + def output_shape(self): + return self.n_output_channels + + +class PointCloudEncoderConfig: + def __init__(self, out_channels=128, num_points=4096): + self.out_channels = out_channels + self.num_points = num_points + +if __name__ == "__main__": + + pc_model = load_pc_model("/home/ruihengwang/vla/VLA-Adapter/pretrained_models/pi3_checkpoint") + batch_size = 2 + out_channels = 128 + num_points = 4096 + observation_space = { + 'point_cloud': (num_points, 3), + } + pointcloud_encoder_cfg = PointCloudEncoderConfig( + out_channels=out_channels, + num_points=num_points + ) + encoder = iDP3Encoder( + observation_space=observation_space, + pointcloud_encoder_cfg=pointcloud_encoder_cfg, + point_downsample=True + ) + encoder.eval() + point_cloud = torch.randn(batch_size, num_points, 3) + + observations = { + 'point_cloud': point_cloud + } + with torch.no_grad(): + output = encoder(observations) + print(f"Input shape: {observations['point_cloud'].shape}") + print(f"\n输出特征形状: {output.shape}") + print(f"输出特征范围: [{output.min():.3f}, {output.max():.3f}]") + print(f"输出维度: {encoder.output_shape()}") diff --git a/prismatic/models/vlas/__pycache__/__init__.cpython-310.pyc b/prismatic/models/vlas/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 9a05ffb..0000000 Binary files a/prismatic/models/vlas/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/vlas/__pycache__/openvla.cpython-310.pyc b/prismatic/models/vlas/__pycache__/openvla.cpython-310.pyc deleted file mode 100644 index 7437f32..0000000 Binary 
files a/prismatic/models/vlas/__pycache__/openvla.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/vlms/__pycache__/__init__.cpython-310.pyc b/prismatic/models/vlms/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index dba07bf..0000000 Binary files a/prismatic/models/vlms/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/vlms/__pycache__/base_vlm.cpython-310.pyc b/prismatic/models/vlms/__pycache__/base_vlm.cpython-310.pyc deleted file mode 100644 index 996f58d..0000000 Binary files a/prismatic/models/vlms/__pycache__/base_vlm.cpython-310.pyc and /dev/null differ diff --git a/prismatic/models/vlms/__pycache__/prismatic.cpython-310.pyc b/prismatic/models/vlms/__pycache__/prismatic.cpython-310.pyc deleted file mode 100644 index fa8f1d1..0000000 Binary files a/prismatic/models/vlms/__pycache__/prismatic.cpython-310.pyc and /dev/null differ diff --git a/prismatic/training/strategies/__pycache__/__init__.cpython-310.pyc b/prismatic/training/strategies/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 5c96e20..0000000 Binary files a/prismatic/training/strategies/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/training/strategies/__pycache__/base_strategy.cpython-310.pyc b/prismatic/training/strategies/__pycache__/base_strategy.cpython-310.pyc deleted file mode 100644 index daf7ed2..0000000 Binary files a/prismatic/training/strategies/__pycache__/base_strategy.cpython-310.pyc and /dev/null differ diff --git a/prismatic/training/strategies/__pycache__/ddp.cpython-310.pyc b/prismatic/training/strategies/__pycache__/ddp.cpython-310.pyc deleted file mode 100644 index e14d78f..0000000 Binary files a/prismatic/training/strategies/__pycache__/ddp.cpython-310.pyc and /dev/null differ diff --git a/prismatic/training/strategies/__pycache__/fsdp.cpython-310.pyc b/prismatic/training/strategies/__pycache__/fsdp.cpython-310.pyc deleted file mode 100644 index 
6109ad4..0000000 Binary files a/prismatic/training/strategies/__pycache__/fsdp.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/__pycache__/__init__.cpython-310.pyc b/prismatic/vla/datasets/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index da9750c..0000000 Binary files a/prismatic/vla/datasets/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/__pycache__/datasets.cpython-310.pyc b/prismatic/vla/datasets/__pycache__/datasets.cpython-310.pyc deleted file mode 100644 index 0eab071..0000000 Binary files a/prismatic/vla/datasets/__pycache__/datasets.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/__pycache__/__init__.cpython-310.pyc b/prismatic/vla/datasets/rlds/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 444a3ae..0000000 Binary files a/prismatic/vla/datasets/rlds/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/__pycache__/dataset.cpython-310.pyc b/prismatic/vla/datasets/rlds/__pycache__/dataset.cpython-310.pyc deleted file mode 100644 index b27339b..0000000 Binary files a/prismatic/vla/datasets/rlds/__pycache__/dataset.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/__pycache__/obs_transforms.cpython-310.pyc b/prismatic/vla/datasets/rlds/__pycache__/obs_transforms.cpython-310.pyc deleted file mode 100644 index 01b6a49..0000000 Binary files a/prismatic/vla/datasets/rlds/__pycache__/obs_transforms.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/__pycache__/traj_transforms.cpython-310.pyc b/prismatic/vla/datasets/rlds/__pycache__/traj_transforms.cpython-310.pyc deleted file mode 100644 index 49363bb..0000000 Binary files a/prismatic/vla/datasets/rlds/__pycache__/traj_transforms.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/oxe/__pycache__/__init__.cpython-310.pyc 
b/prismatic/vla/datasets/rlds/oxe/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index a2a9824..0000000 Binary files a/prismatic/vla/datasets/rlds/oxe/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/oxe/__pycache__/configs.cpython-310.pyc b/prismatic/vla/datasets/rlds/oxe/__pycache__/configs.cpython-310.pyc deleted file mode 100644 index 4770256..0000000 Binary files a/prismatic/vla/datasets/rlds/oxe/__pycache__/configs.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/oxe/__pycache__/materialize.cpython-310.pyc b/prismatic/vla/datasets/rlds/oxe/__pycache__/materialize.cpython-310.pyc deleted file mode 100644 index 6949d86..0000000 Binary files a/prismatic/vla/datasets/rlds/oxe/__pycache__/materialize.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/oxe/__pycache__/mixtures.cpython-310.pyc b/prismatic/vla/datasets/rlds/oxe/__pycache__/mixtures.cpython-310.pyc deleted file mode 100644 index 53d79d4..0000000 Binary files a/prismatic/vla/datasets/rlds/oxe/__pycache__/mixtures.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/oxe/__pycache__/transforms.cpython-310.pyc b/prismatic/vla/datasets/rlds/oxe/__pycache__/transforms.cpython-310.pyc deleted file mode 100644 index 742579a..0000000 Binary files a/prismatic/vla/datasets/rlds/oxe/__pycache__/transforms.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/oxe/configs.py b/prismatic/vla/datasets/rlds/oxe/configs.py index 3222e02..1e550a2 100644 --- a/prismatic/vla/datasets/rlds/oxe/configs.py +++ b/prismatic/vla/datasets/rlds/oxe/configs.py @@ -177,6 +177,13 @@ class ActionEncoding(IntEnum): "state_encoding": StateEncoding.POS_EULER, "action_encoding": ActionEncoding.EEF_POS, }, + "calvin_abc_rlds": { + "image_obs_keys": {"primary": "rgb_static", "secondary": None, "wrist": "rgb_gripper"}, + "depth_obs_keys": {"primary": None, "secondary": None, 
"wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, "columbia_cairlab_pusht_real": { "image_obs_keys": { "primary": "image", diff --git a/prismatic/vla/datasets/rlds/oxe/utils/__pycache__/droid_utils.cpython-310.pyc b/prismatic/vla/datasets/rlds/oxe/utils/__pycache__/droid_utils.cpython-310.pyc deleted file mode 100644 index 9b77c30..0000000 Binary files a/prismatic/vla/datasets/rlds/oxe/utils/__pycache__/droid_utils.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/utils/__pycache__/__init__.cpython-310.pyc b/prismatic/vla/datasets/rlds/utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index b0c94df..0000000 Binary files a/prismatic/vla/datasets/rlds/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/utils/__pycache__/data_utils.cpython-310.pyc b/prismatic/vla/datasets/rlds/utils/__pycache__/data_utils.cpython-310.pyc deleted file mode 100644 index a07036c..0000000 Binary files a/prismatic/vla/datasets/rlds/utils/__pycache__/data_utils.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/utils/__pycache__/goal_relabeling.cpython-310.pyc b/prismatic/vla/datasets/rlds/utils/__pycache__/goal_relabeling.cpython-310.pyc deleted file mode 100644 index d8e365f..0000000 Binary files a/prismatic/vla/datasets/rlds/utils/__pycache__/goal_relabeling.cpython-310.pyc and /dev/null differ diff --git a/prismatic/vla/datasets/rlds/utils/__pycache__/task_augmentation.cpython-310.pyc b/prismatic/vla/datasets/rlds/utils/__pycache__/task_augmentation.cpython-310.pyc deleted file mode 100644 index 42ee8ea..0000000 Binary files a/prismatic/vla/datasets/rlds/utils/__pycache__/task_augmentation.cpython-310.pyc and /dev/null differ diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..6d9429f --- /dev/null +++ b/run.sh @@ -0,0 +1,37 @@ +# 
data_name=calvin_abc_rlds +data_name=libero_10_no_noops +export HF_HUB_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +export HF_DATASETS_OFFLINE=1 +CUDA_VISIBLE_DEVICES=4,5 torchrun --standalone --nnodes 1 --nproc-per-node 2 vla-scripts/finetune.py \ + --vlm_path pretrained_models/prism-qwen25-extra-dinosiglip-224px-0_5b \ + --config_file_path pretrained_models/configs \ + --data_root_dir data/libero \ + --dataset_name $data_name \ + --run_root_dir outputs \ + --use_film False \ + --num_images_in_input 2 \ + --use_proprio True \ + --use_lora True \ + --use_fz False \ + --use_minivlm True \ + --image_aug True \ + --num_steps_before_decay 200000 \ + --max_steps 200005 \ + --save_freq 10000 \ + --save_latest_checkpoint_only False \ + --merge_lora_during_training True \ + --batch_size 8 \ + --grad_accumulation_steps 2 \ + --learning_rate 2e-4 \ + --lora_rank 64 \ + --use_pro_version True \ + --wandb_entity "my-wandb-org" \ + --wandb_project "$data_name" \ + --use_3d True \ + --inject_layers all \ + --run_id_note VLA-Adapter--$data_name--$(date "+%Y_%m_%d_%H_%M_%S") \ + # --resume True \ + # --resum_vla_path outputs/configs+calvin_abc_rlds+b16+lr-0.0002+lora-r64+dropout-0.0--image_aug--VLA-Adapter--calvin_abc_rlds--2025_10_12_19_33_45--110000_chkpt \ + # --resume_step 110000 \ + # > experiments/logs/Train--$data_name--$(date "+%Y_%m_%d_%H_%M_%S").log 2>&1 & \ No newline at end of file diff --git a/vla-scripts/finetune.py b/vla-scripts/finetune.py index 03263c1..d1d6f88 100644 --- a/vla-scripts/finetune.py +++ b/vla-scripts/finetune.py @@ -3,7 +3,7 @@ Fine-tunes Qwen2.5-0.5B via LoRA. 
""" - +from typing import Dict, List, Optional, Tuple, Union import os import time from collections import deque @@ -57,6 +57,7 @@ from prismatic.vla.datasets import RLDSDataset, RLDSBatchTransform from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics from prismatic.models import load, load_vla +from prismatic.models.pi3_loader import load_pc_model @@ -70,6 +71,7 @@ class FinetuneConfig: vlm_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally) use_minivlm: bool = False # resum_vla_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally) + pi3_path: Path = Path("/home/ruihengwang/vla/VLA-Adapter/pretrained_models/pi3_checkpoint") # Dataset data_root_dir: Path = Path("datasets/rlds") # Directory containing RLDS datasets @@ -125,6 +127,9 @@ class FinetuneConfig: # revision version use_pro_version: bool = True # the version number phase: str = "Training" + use_3d: bool = False + dim_3d: int = 2048 + inject_layers: Union[int, List[int], str] = 0 # fmt: on @@ -188,6 +193,7 @@ def get_run_id(cfg) -> str: run_id += "--image_aug" if cfg.run_id_note is not None: run_id += f"--{cfg.run_id_note}" + run_id += f"--use_3d_{cfg.use_3d}_dim_{cfg.dim_3d}_inject_{cfg.inject_layers}" return run_id @@ -298,7 +304,8 @@ def run_forward_pass( num_patches, compute_diffusion_l1=False, use_pro_version=True, - cfg=None + cfg=None, + **kwargs ) -> Tuple[torch.Tensor, Dict[str, float]]: """ Compute model forward pass and metrics for both training and validation. 
@@ -330,7 +337,8 @@ def run_forward_pass( # Get ground-truth action labels ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16) noise, noisy_actions, diffusion_timestep_embeddings = None, None, None - + pi3_model = kwargs.get("pi3_model", None) + img_1, img_2 = batch["pixel_values"][:, 0:3, :, :].to(device_id).to(torch.bfloat16), batch["pixel_values"][:, 6:9, :, :].to(device_id).to(torch.bfloat16) # VLA forward pass with torch.autocast("cuda", dtype=torch.bfloat16): output: CausalLMOutputWithPast = vla( @@ -346,13 +354,32 @@ def run_forward_pass( diffusion_timestep_embeddings=None, use_film=use_film, ) + if pi3_model is not None: + pi3_num_reg_token = 5 + + img_tensor = torch.stack([img_1, img_2], dim=1) # [B, 2, 3, H, W] where 2 indicates 2 views + B, N, _, H, W = img_tensor.shape + img_tensor = img_tensor.reshape((B*N, _, H, W)) + hidden = pi3_model.encoder(img_tensor, is_training=True) + if isinstance(hidden, dict): + hidden = hidden["x_norm_patchtokens"] + hidden, pos = pi3_model.decode(hidden, N, H, W) + hidden = hidden[:, pi3_num_reg_token:, :] + L_3d, dim_3d = hidden.shape[-2:] + hidden = hidden.reshape(B, -1, L_3d, dim_3d) + hidden = hidden.reshape(B, -1, dim_3d) # Get action masks needed for logging + #* batch["labels"] 是 L 个(L_a+L_lang),第一个是 BOS token,这样 :, 1: 是索引第 2-L 个。 + #* current_action_mask 索引了 L 中 L_a 里面前 6 个。 + #* next_action_mask 索引了 L 中 L_a 里后面 58 个(设定了 64 个 action tokens) ground_truth_token_ids = batch["labels"][:,1:].to(device_id) current_action_mask = get_current_action_mask(ground_truth_token_ids) next_actions_mask = get_next_actions_mask(ground_truth_token_ids) # Compute metrics for discrete action representation (next-token prediction) + + if not (use_l1_regression): loss = output.loss predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2) @@ -394,7 +421,12 @@ def run_forward_pass( else: # Get last layer hidden states multi_layer_hidden_states = [] - + #* 每一层 [B, 1 + L_v + (L_a + L_lang -1), Dim] 的 
hidden_states + #* text_hidden_states 就是 [B, L_a + L_lang -1, Dim] + #* actions_hidden_states 就对应索引取出 action 是 True 的部分: [B, 1, L_a, Dim] + #* task_latten_states 是 vision 部分,也就是 [B, 1, L_v, Dim] + #* 这二者 cat 在一起,也就是 [B, 1, L_v + L_a, Dim] + #* 若一共 H 层,那我们最后在 维度 1 上 cat 即可得到 [B, H, L_v + L_a, Dim],这就是输入给 action head 的。H 是中间层数。 for item in output.hidden_states[0:]: # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D) # Get hidden states for text portion of prompt+response (after the vision patches) @@ -413,6 +445,7 @@ def run_forward_pass( proprio=batch["proprio"] if use_proprio else None, proprio_projector=proprio_projector if use_proprio else None, phase=cfg.phase, + hidden_3d=hidden.to(torch.bfloat16) ) loss = torch.nn.L1Loss()(predicted_actions, ground_truth_actions) @@ -719,6 +752,31 @@ def finetune(cfg: FinetuneConfig) -> None: # Create experiment run directory run_dir = cfg.run_root_dir / run_id os.makedirs(run_dir, exist_ok=True) + from omegaconf import OmegaConf + from dataclasses import asdict + import json + cfg_dict = cfg if isinstance(cfg, dict) else \ + OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else \ + asdict(cfg) # dataclass + + # 2. Path 对象转字符串,保证可 JSON 序列化 + def _convert_path(obj): + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, dict): + return {k: _convert_path(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_convert_path(i) for i in obj] + return obj + + cfg_dict = _convert_path(cfg_dict) + + # 3. 
写入 run_dir + config_json = run_dir / "config.json" + with config_json.open("w", encoding="utf-8") as f: + json.dump(cfg_dict, f, indent=2, ensure_ascii=False) + + print(f"Config saved to {config_json}") # GPU setup distributed_state = PartialState() @@ -728,7 +786,7 @@ def finetune(cfg: FinetuneConfig) -> None: # Initialize wandb logging if distributed_state.is_main_process: - wandb.init(project=cfg.wandb_project, name=f"ft+{run_id}", mode="offline") + wandb.init(project=cfg.wandb_project, name=f"ft+{run_id}", mode="online") #TODO: set online when necessary # Print detected constants print( @@ -775,7 +833,7 @@ def finetune(cfg: FinetuneConfig) -> None: processor = AutoProcessor.from_pretrained(cfg.config_file_path, trust_remote_code=True) if cfg.use_minivlm: - hf_token = '' + hf_token = '' if 'prism-qwen25-extra-dinosiglip-224px-0_5b' in cfg.vlm_path: vlm = load(cfg.vlm_path, hf_token=hf_token, load_for_training=True) @@ -891,9 +949,16 @@ def rename_state_dict_keys(state_dict, replace_map): "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM, "use_pro_version": cfg.use_pro_version, + "use_3d": cfg.use_3d, + "dim_3d": cfg.dim_3d, + "inject_layers": cfg.inject_layers }, to_bf16=True, ) + pi3_model = load_pc_model(cfg.pi3_path).to(device_id).to(torch.bfloat16) + pi3_model.eval() + for name, param in pi3_model.named_parameters(): + param.requires_grad = False # Get number of vision patches NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input() @@ -1033,6 +1098,7 @@ def rename_state_dict_keys(state_dict, replace_map): compute_diffusion_l1=compute_diffusion_l1, use_pro_version=cfg.use_pro_version, cfg=cfg, + pi3_model=pi3_model ) # Normalize loss to account for gradient accumulation diff --git a/vla_adapter.egg-info/PKG-INFO b/vla_adapter.egg-info/PKG-INFO deleted file mode 100644 index 7ad4e9b..0000000 --- a/vla_adapter.egg-info/PKG-INFO +++ /dev/null @@ -1,181 +0,0 @@ -Metadata-Version: 2.4 -Name: 
openvla-oft -Version: 0.0.1 -Summary: Fine-Tuning Vision-Language-Action Models: Optimizing Speed and Success -Author-email: Moo Jin Kim , Chelsea Finn , Percy Liang -License: MIT License - - Copyright (c) 2025 Moo Jin Kim, Chelsea Finn, Percy Liang. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. 
- -Project-URL: homepage, https://github.com/moojink/openvla-oft -Project-URL: repository, https://github.com/moojink/openvla-oft -Project-URL: documentation, https://github.com/moojink/openvla-oft -Keywords: vision-language-actions models,fine-tuning,robot learning -Classifier: Development Status :: 3 - Alpha -Classifier: Intended Audience :: Developers -Classifier: Intended Audience :: Education -Classifier: Intended Audience :: Science/Research -Classifier: License :: OSI Approved :: MIT License -Classifier: Operating System :: OS Independent -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.8 -Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3 :: Only -Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence -Requires-Python: >=3.8 -Description-Content-Type: text/markdown -License-File: LICENSE -Requires-Dist: accelerate>=0.25.0 -Requires-Dist: draccus==0.8.0 -Requires-Dist: einops -Requires-Dist: huggingface_hub -Requires-Dist: json-numpy -Requires-Dist: jsonlines -Requires-Dist: matplotlib -Requires-Dist: peft==0.11.1 -Requires-Dist: protobuf -Requires-Dist: rich -Requires-Dist: sentencepiece==0.1.99 -Requires-Dist: timm==0.9.10 -Requires-Dist: tokenizers==0.19.1 -Requires-Dist: torch==2.2.0 -Requires-Dist: torchvision==0.17.0 -Requires-Dist: torchaudio==2.2.0 -Requires-Dist: transformers@ git+https://github.com/moojink/transformers-openvla-oft.git -Requires-Dist: wandb -Requires-Dist: tensorflow==2.15.0 -Requires-Dist: tensorflow_datasets==4.9.3 -Requires-Dist: tensorflow_graphics==2021.12.3 -Requires-Dist: dlimp@ git+https://github.com/moojink/dlimp_openvla -Requires-Dist: diffusers -Requires-Dist: imageio -Requires-Dist: uvicorn -Requires-Dist: fastapi -Requires-Dist: json-numpy -Provides-Extra: dev -Requires-Dist: black>=24.2.0; extra == "dev" -Requires-Dist: gpustat; extra == "dev" 
-Requires-Dist: ipython; extra == "dev" -Requires-Dist: pre-commit; extra == "dev" -Requires-Dist: ruff>=0.2.2; extra == "dev" -Provides-Extra: sagemaker -Requires-Dist: boto3; extra == "sagemaker" -Requires-Dist: sagemaker; extra == "sagemaker" -Dynamic: license-file - -# Fine-Tuning Vision-Language-Action Models: Optimizing Speed and Success - -**Project website: https://openvla-oft.github.io/** - -**Paper: https://arxiv.org/abs/2502.19645** - -**Summary video: https://youtu.be/T3Zkkr_NTSA** - -## System Requirements - -Inference: -* 1 GPU with ~16 GB VRAM for LIBERO sim benchmark tasks -* 1 GPU with ~18 GB VRAM for ALOHA robot tasks - -Training: -* Between 1-8 GPUs with 27-80 GB, depending on the desired training setup (with default bfloat16 data type). See [this FAQ on our project website](https://openvla-oft.github.io/#train-compute) for details. - -## Quick Start - -First, set up a conda environment (see instructions in [SETUP.md](SETUP.md)). - -Then, run the Python script below to download a pretrained OpenVLA-OFT checkpoint and run inference to generate an action chunk: - -```python -import pickle -from experiments.robot.libero.run_libero_eval import GenerateConfig -from experiments.robot.openvla_utils import get_action_head, get_processor, get_proprio_projector, get_vla, get_vla_action -from prismatic.vla.constants import NUM_ACTIONS_CHUNK, PROPRIO_DIM - -# Instantiate config (see class GenerateConfig in experiments/robot/libero/run_libero_eval.py for definitions) -cfg = GenerateConfig( - pretrained_checkpoint = "moojink/openvla-7b-oft-finetuned-libero-spatial", - use_l1_regression = True, - use_diffusion = False, - use_film = False, - num_images_in_input = 2, - use_proprio = True, - load_in_8bit = False, - load_in_4bit = False, - center_crop = True, - num_open_loop_steps = NUM_ACTIONS_CHUNK, - unnorm_key = "libero_spatial_no_noops", -) - -# Load OpenVLA-OFT policy and inputs processor -vla = get_vla(cfg) -processor = get_processor(cfg) - -# Load MLP 
action head to generate continuous actions (via L1 regression) -action_head = get_action_head(cfg, llm_dim=vla.llm_dim) - -# Load proprio projector to map proprio to language embedding space -proprio_projector = get_proprio_projector(cfg, llm_dim=vla.llm_dim, proprio_dim=PROPRIO_DIM) - -# Load sample observation: -# observation (dict): { -# "full_image": primary third-person image, -# "wrist_image": wrist-mounted camera image, -# "state": robot proprioceptive state, -# "task_description": task description, -# } -with open("experiments/robot/libero/sample_libero_spatial_observation.pkl", "rb") as file: - observation = pickle.load(file) - -# Generate robot action chunk (sequence of future actions) -actions = get_vla_action(cfg, vla, processor, observation, observation["task_description"], action_head, proprio_projector) -print("Generated action chunk:") -for act in actions: - print(act) -``` - -## Installation - -See [SETUP.md](SETUP.md) for instructions on setting up the conda environment. - -## Training and Evaluation - -See [LIBERO.md](LIBERO.md) for fine-tuning/evaluating on LIBERO simulation benchmark task suites. - -See [ALOHA.md](ALOHA.md) for fine-tuning/evaluating on real-world ALOHA robot tasks. - -## Support - -If you run into any issues, please open a new GitHub issue. If you do not receive a response within 2 business days, please email Moo Jin Kim (moojink@cs.stanford.edu) to bring the issue to his attention. 
- -## Citation - -If you use our code in your work, please cite [our paper](https://arxiv.org/abs/2502.19645): - -```bibtex -@article{kim2025fine, - title={Fine-Tuning Vision-Language-Action Models: Optimizing Speed and Success}, - author={Kim, Moo Jin and Finn, Chelsea and Liang, Percy}, - journal={arXiv preprint arXiv:2502.19645}, - year={2025} -} -``` diff --git a/vla_adapter.egg-info/SOURCES.txt b/vla_adapter.egg-info/SOURCES.txt deleted file mode 100644 index 7ed48a4..0000000 --- a/vla_adapter.egg-info/SOURCES.txt +++ /dev/null @@ -1,118 +0,0 @@ -LICENSE -README.md -pyproject.toml -experiments/robot/openvla_utils.py -experiments/robot/robot_utils.py -experiments/robot/aloha/aloha_utils.py -experiments/robot/aloha/constants.py -experiments/robot/aloha/preprocess_split_aloha_data.py -experiments/robot/aloha/real_env.py -experiments/robot/aloha/robot_utils.py -experiments/robot/aloha/run_aloha_eval.py -experiments/robot/bridge/bridgev2_utils.py -experiments/robot/bridge/run_bridgev2_eval.py -experiments/robot/bridge/widowx_env.py -experiments/robot/libero/libero_utils.py -experiments/robot/libero/regenerate_libero_dataset.py -experiments/robot/libero/run_libero_eval.py -openvla_oft.egg-info/PKG-INFO -openvla_oft.egg-info/SOURCES.txt -openvla_oft.egg-info/dependency_links.txt -openvla_oft.egg-info/requires.txt -openvla_oft.egg-info/top_level.txt -prismatic/__init__.py -prismatic/py.typed -prismatic/conf/__init__.py -prismatic/conf/datasets.py -prismatic/conf/models.py -prismatic/conf/vla.py -prismatic/extern/__init__.py -prismatic/extern/hf/__init__.py -prismatic/extern/hf/configuration_prismatic.py -prismatic/extern/hf/modeling_prismatic.py -prismatic/extern/hf/processing_prismatic.py -prismatic/models/__init__.py -prismatic/models/action_heads.py -prismatic/models/film_vit_wrapper.py -prismatic/models/load.py -prismatic/models/materialize.py -prismatic/models/projectors.py -prismatic/models/registry.py -prismatic/models/backbones/__init__.py 
-prismatic/models/backbones/llm/__init__.py -prismatic/models/backbones/llm/base_llm.py -prismatic/models/backbones/llm/llama2.py -prismatic/models/backbones/llm/mistral.py -prismatic/models/backbones/llm/phi.py -prismatic/models/backbones/llm/prompting/__init__.py -prismatic/models/backbones/llm/prompting/base_prompter.py -prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py -prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py -prismatic/models/backbones/llm/prompting/phi_prompter.py -prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py -prismatic/models/backbones/vision/__init__.py -prismatic/models/backbones/vision/base_vision.py -prismatic/models/backbones/vision/clip_vit.py -prismatic/models/backbones/vision/dinoclip_vit.py -prismatic/models/backbones/vision/dinosiglip_vit.py -prismatic/models/backbones/vision/dinov2_vit.py -prismatic/models/backbones/vision/in1k_vit.py -prismatic/models/backbones/vision/siglip_vit.py -prismatic/models/vlas/__init__.py -prismatic/models/vlas/openvla.py -prismatic/models/vlms/__init__.py -prismatic/models/vlms/base_vlm.py -prismatic/models/vlms/prismatic.py -prismatic/overwatch/__init__.py -prismatic/overwatch/overwatch.py -prismatic/preprocessing/__init__.py -prismatic/preprocessing/download.py -prismatic/preprocessing/materialize.py -prismatic/preprocessing/datasets/__init__.py -prismatic/preprocessing/datasets/datasets.py -prismatic/training/__init__.py -prismatic/training/materialize.py -prismatic/training/metrics.py -prismatic/training/train_utils.py -prismatic/training/strategies/__init__.py -prismatic/training/strategies/base_strategy.py -prismatic/training/strategies/ddp.py -prismatic/training/strategies/fsdp.py -prismatic/util/__init__.py -prismatic/util/batching_utils.py -prismatic/util/data_utils.py -prismatic/util/nn_utils.py -prismatic/util/torch_utils.py -prismatic/vla/__init__.py -prismatic/vla/action_tokenizer.py -prismatic/vla/constants.py -prismatic/vla/materialize.py 
-prismatic/vla/datasets/__init__.py -prismatic/vla/datasets/datasets.py -prismatic/vla/datasets/rlds/__init__.py -prismatic/vla/datasets/rlds/dataset.py -prismatic/vla/datasets/rlds/obs_transforms.py -prismatic/vla/datasets/rlds/traj_transforms.py -prismatic/vla/datasets/rlds/oxe/__init__.py -prismatic/vla/datasets/rlds/oxe/configs.py -prismatic/vla/datasets/rlds/oxe/materialize.py -prismatic/vla/datasets/rlds/oxe/mixtures.py -prismatic/vla/datasets/rlds/oxe/transforms.py -prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py -prismatic/vla/datasets/rlds/utils/__init__.py -prismatic/vla/datasets/rlds/utils/data_utils.py -prismatic/vla/datasets/rlds/utils/goal_relabeling.py -prismatic/vla/datasets/rlds/utils/task_augmentation.py -scripts/generate.py -scripts/preprocess.py -scripts/pretrain.py -scripts/additional-datasets/lrv_instruct.py -scripts/additional-datasets/lvis_instruct_4v.py -scripts/extern/convert_prismatic_weights_to_hf.py -scripts/extern/verify_prismatic.py -vla-scripts/deploy.py -vla-scripts/finetune.py -vla-scripts/merge_lora_weights_and_save.py -vla-scripts/train.py -vla-scripts/extern/convert_openvla_weights_to_hf.py -vla-scripts/extern/verify_openvla.py \ No newline at end of file diff --git a/vla_adapter.egg-info/dependency_links.txt b/vla_adapter.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/vla_adapter.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/vla_adapter.egg-info/requires.txt b/vla_adapter.egg-info/requires.txt deleted file mode 100644 index 1dd95a1..0000000 --- a/vla_adapter.egg-info/requires.txt +++ /dev/null @@ -1,38 +0,0 @@ -accelerate>=0.25.0 -draccus==0.8.0 -einops -huggingface_hub -json-numpy -jsonlines -matplotlib -peft==0.11.1 -protobuf -rich -sentencepiece==0.1.99 -timm==0.9.10 -tokenizers==0.19.1 -torch==2.2.0 -torchvision==0.17.0 -torchaudio==2.2.0 -transformers@ git+https://github.com/moojink/transformers-openvla-oft.git -wandb -tensorflow==2.15.0 
-tensorflow_datasets==4.9.3 -tensorflow_graphics==2021.12.3 -dlimp@ git+https://github.com/moojink/dlimp_openvla -diffusers -imageio -uvicorn -fastapi -json-numpy - -[dev] -black>=24.2.0 -gpustat -ipython -pre-commit -ruff>=0.2.2 - -[sagemaker] -boto3 -sagemaker diff --git a/vla_adapter.egg-info/top_level.txt b/vla_adapter.egg-info/top_level.txt deleted file mode 100644 index 16a23a4..0000000 --- a/vla_adapter.egg-info/top_level.txt +++ /dev/null @@ -1,4 +0,0 @@ -experiments -prismatic -scripts -vla-scripts