diff --git a/rfdetr/main.py b/rfdetr/main.py index 7c18a25f6..e2f61adf1 100644 --- a/rfdetr/main.py +++ b/rfdetr/main.py @@ -440,6 +440,8 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_ device = self.device model = deepcopy(self.model.to("cpu")) + if backbone_only: + model = model.backbone model.to(device) os.makedirs(output_dir, exist_ok=True) @@ -453,11 +455,25 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_ input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device) input_names = ['input'] output_names = ['features'] if backbone_only else ['dets', 'labels'] - dynamic_axes = None + dynamic_axes = { + 'input': {0: 'batch_size', 2: 'height', 3: 'width'}, + } + if backbone_only: + dynamic_axes.update({ + 'features': {0: 'batch_size', 2: 'num_patches_height', 3: 'num_patches_width'}, + }) + else: + dynamic_axes.update({ + 'dets': {0: 'batch_size'}, + 'labels': {0: 'batch_size'}, + }) + self.model.eval() with torch.no_grad(): if backbone_only: - features = model(input_tensors) + features = model( + utils.nested_tensor_from_tensor_list([input_tensors[0]]) + )[0][0].tensors print(f"PyTorch inference output shape: {features.shape}") else: outputs = model(input_tensors) diff --git a/rfdetr/models/backbone/dinov2.py b/rfdetr/models/backbone/dinov2.py index 82d8f2277..dc354d600 100644 --- a/rfdetr/models/backbone/dinov2.py +++ b/rfdetr/models/backbone/dinov2.py @@ -67,61 +67,6 @@ def export(self): if self._export: return self._export = True - shape = self.shape - def make_new_interpolated_pos_encoding( - position_embeddings, patch_size, height, width - ): - - num_positions = position_embeddings.shape[1] - 1 - dim = position_embeddings.shape[-1] - height = height // patch_size - width = width // patch_size - - class_pos_embed = position_embeddings[:, 0] - patch_pos_embed = position_embeddings[:, 1:] - - # Reshape and permute - patch_pos_embed = patch_pos_embed.reshape( - 1, 
int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim - ) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - - # Use bilinear interpolation without antialias - patch_pos_embed = F.interpolate( - patch_pos_embed, - size=(height, width), - mode="bicubic", - align_corners=False, - antialias=True, - ) - - # Reshape back - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) - - # If the shape of self.encoder.embeddings.position_embeddings - # matches the shape of your new tensor, use copy_: - with torch.no_grad(): - new_positions = make_new_interpolated_pos_encoding( - self.encoder.embeddings.position_embeddings, - self.encoder.config.patch_size, - shape[0], - shape[1], - ) - # Create a new Parameter with the new size - old_interpolate_pos_encoding = self.encoder.embeddings.interpolate_pos_encoding - def new_interpolate_pos_encoding(self_mod, embeddings, height, width): - num_patches = embeddings.shape[1] - 1 - num_positions = self_mod.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: - return self_mod.position_embeddings - return old_interpolate_pos_encoding(embeddings, height, width) - - self.encoder.embeddings.position_embeddings = nn.Parameter(new_positions) - self.encoder.embeddings.interpolate_pos_encoding = types.MethodType( - new_interpolate_pos_encoding, - self.encoder.embeddings - ) def forward(self, x): assert x.shape[2] % 14 == 0 and x.shape[3] % 14 == 0, f"Dinov2 requires input shape to be divisible by 14, but got {x.shape}" diff --git a/rfdetr/models/backbone/dinov2_with_windowed_attn.py b/rfdetr/models/backbone/dinov2_with_windowed_attn.py index 93ca6bf9c..ba62a0028 100644 --- a/rfdetr/models/backbone/dinov2_with_windowed_attn.py +++ b/rfdetr/models/backbone/dinov2_with_windowed_attn.py @@ -283,7 +283,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: 
size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor mode="bicubic", align_corners=False, - antialias=True, + # True by default, False during export + antialias=not torch.onnx.is_in_onnx_export(), ).to(dtype=target_dtype) # Validate output dimensions if not tracing diff --git a/rfdetr/models/ops/modules/ms_deform_attn.py b/rfdetr/models/ops/modules/ms_deform_attn.py index 8af2bcc23..efb02c127 100644 --- a/rfdetr/models/ops/modules/ms_deform_attn.py +++ b/rfdetr/models/ops/modules/ms_deform_attn.py @@ -95,7 +95,7 @@ def _reset_parameters(self): def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): - """ + r""" :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes diff --git a/rfdetr/models/transformer.py b/rfdetr/models/transformer.py index c7b32a6d6..0cfcdc740 100644 --- a/rfdetr/models/transformer.py +++ b/rfdetr/models/transformer.py @@ -68,7 +68,7 @@ def gen_sineembed_for_position(pos_tensor, dim=128): def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, unsigmoid=True): - """ + r""" Input: - memory: bs, \sum{hw}, d_model - memory_padding_mask: bs, \sum{hw} @@ -198,12 +198,12 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat): src_flatten = [] mask_flatten = [] if masks is not None else None lvl_pos_embed_flatten = [] - spatial_shapes = [] + spatial_shapes = torch.empty((len(srcs), 2), device=srcs[0].device, dtype=torch.long) valid_ratios = [] if masks is not None else None for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)): bs, c, h, w = src.shape - spatial_shape = (h, w) - spatial_shapes.append(spatial_shape) + spatial_shapes[lvl, 0] = h + spatial_shapes[lvl, 1] = w src = src.flatten(2).transpose(1, 
2) # bs, hw, c pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c @@ -217,7 +217,6 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat): mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c - spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) if self.two_stage: