Skip to content
8 changes: 7 additions & 1 deletion src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def export(
force: bool = False,
shape: tuple = None,
batch_size: int = 1,
dynamic_batch: bool = False,
**kwargs,
) -> None:
"""Export the trained model to ONNX format.
Expand All @@ -296,6 +297,8 @@ def export(
force: Deprecated and ignored.
shape: ``(height, width)`` tuple; defaults to square at model resolution.
batch_size: Batch size baked into the ONNX graph as a static dimension; when ``dynamic_batch`` is True, axis 0 is exported as dynamic instead.
dynamic_batch: If True, export with a dynamic batch dimension
so the ONNX model accepts variable batch sizes at runtime.
**kwargs: Additional keyword arguments forwarded to export_onnx.
"""
logger.info("Exporting model to ONNX format")
Expand Down Expand Up @@ -329,7 +332,10 @@ def export(
else:
output_names = ["dets", "labels"]

dynamic_axes = None
if dynamic_batch:
dynamic_axes = {name: {0: "batch"} for name in input_names + output_names}
else:
dynamic_axes = None
model.eval()
with torch.no_grad():
if backbone_only:
Expand Down
5 changes: 4 additions & 1 deletion src/rfdetr/export/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def main(args):
output_names = ["dets", "labels", "masks"]
else:
output_names = ["dets", "labels"]
dynamic_axes = None
if getattr(args, "dynamic_batch", False):
dynamic_axes = {name: {0: "batch"} for name in input_names + output_names}
else:
dynamic_axes = None
# Run model inference in pytorch mode
model.eval().to("cuda")
input_tensors = input_tensors.to("cuda")
Expand Down
2 changes: 1 addition & 1 deletion src/rfdetr/models/backbone/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def forward(self, x):
TODO: this is a hack to avoid overflow when using fp16
"""
x = x.permute(0, 2, 3, 1)
x = F.layer_norm(x, (x.size(3),), self.weight, self.bias, self.eps)
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
return x

Expand Down
24 changes: 12 additions & 12 deletions src/rfdetr/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, un
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
else:
valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device)
valid_H = H_.expand(N_)
valid_W = W_.expand(N_)

grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
Expand Down Expand Up @@ -224,12 +224,18 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
src_flatten = []
mask_flatten = [] if masks is not None else None
lvl_pos_embed_flatten = []
spatial_shapes = []
# Build spatial_shapes as a tensor directly so that the ONNX tracer
# can track h/w symbolically instead of baking them in as constants.
spatial_shapes = torch.empty(
(len(srcs), 2),
device=srcs[0].device,
dtype=torch.long,
)
valid_ratios = [] if masks is not None else None
for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
spatial_shapes[lvl, 0] = h
spatial_shapes[lvl, 1] = w

src = src.flatten(2).transpose(1, 2) # bs, hw, c
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
Expand All @@ -243,17 +249,11 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
# Keep a plain Python list of (H, W) int pairs for gen_encoder_output_proposals.
# The tensor form is only needed by ms_deform_attn and level_start_index.
# Passing Python ints avoids .item() calls inside the function, which would
# cause torch.compile graph breaks on loop-accumulated slice indices.
spatial_shapes_hw = list(spatial_shapes)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

if self.two_stage:
output_memory, output_proposals = gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes_hw, unsigmoid=not self.bbox_reparam
memory, mask_flatten, spatial_shapes, unsigmoid=not self.bbox_reparam
)
# group detr for first stage
refpoint_embed_ts, memory_ts, boxes_ts = [], [], []
Expand Down
Loading