Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions src/lightly_train/_commands/export_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ def export_onnx(
out: PathLike,
checkpoint: PathLike,
batch_size: int = 1,
num_channels: int = 3,
height: int = 224,
width: int = 224,
height: int | None = None,
width: int | None = None,
precision: Literal["32-true", "16-true"] = "32-true",
simplify: bool = True,
verify: bool = True,
Expand All @@ -90,8 +89,6 @@ def export_onnx(
Path to the LightlyTrain checkpoint file to export the model from.
batch_size:
Batch size of the input tensor.
num_channels:
Number of channels in input tensor.
height:
Height of the input tensor.
width:
Expand All @@ -118,9 +115,8 @@ def _export_task(
checkpoint: PathLike,
format: Literal["onnx"],
batch_size: int = 1,
num_channels: int = 3,
height: int = 224,
width: int = 224,
height: int | None = None,
width: int | None = None,
precision: Literal["32-true", "16-true"] = "32-true",
simplify: bool = True,
verify: bool = True,
Expand All @@ -138,12 +134,12 @@ def _export_task(
Format to save the model in.
batch_size:
Batch size of the input tensor.
num_channels:
Number of channels in input tensor.
height:
Height of the input tensor. For efficiency reasons we recomment this to be the same as width.
Height of the input tensor. If not specified it will be the same height that the model was trained in.
For efficiency reasons we recommend this to be the same as width.
width:
Width of the input tensor. For efficiency reasons we recomment this to be the same as height.
Width of the input tensor. If not specified it will be the same width that the model was trained in.
For efficiency reasons we recommend this to be the same as height.
precision:
OnnxPrecision.F32_TRUE for float32 precision or OnnxPrecision.F16_TRUE for float16 precision.
simplify:
Expand Down Expand Up @@ -181,16 +177,26 @@ def _export_task_from_config(config: ExportTaskConfig) -> None:
)
task_model.eval()

height = config.height
width = config.width
# TODO we might also use task_model.backbone.in_chans
num_channels = len(task_model.image_normalize["mean"])

if height is None:
height = task_model.image_size[0]
if width is None:
width = task_model.image_size[1]

# Export the model to ONNX format
# TODO(Yutong, 07/25): support more formats (may use ONNX as the intermediate format)
if config.format == "onnx":
# The DinoVisionTransformer _predict method currently raises a RuntimeException when the image size is not
# divisible by the patch size. This only occurs during ONNX export as otherwise we interpolate the input
# image to the correct size.
patch_size = task_model.backbone.patch_size
if not (config.height % patch_size == 0 and config.width % patch_size == 0):
if not (height % patch_size == 0 and width % patch_size == 0):
raise ValueError(
f"Height {config.height} and width {config.width} must be a multiple of patch size {patch_size}."
f"Height {height} and width {width} must be a multiple of patch size {patch_size}."
)

# Get the device of the model to ensure dummy input is on the same device
Expand All @@ -200,9 +206,9 @@ def _export_task_from_config(config: ExportTaskConfig) -> None:

dummy_input = torch.randn(
config.batch_size,
config.num_channels,
config.height,
config.width,
num_channels,
height,
width,
requires_grad=False,
device=model_device,
dtype=onnx_dtype,
Expand Down Expand Up @@ -285,9 +291,8 @@ class ExportTaskConfig(PydanticConfig):
checkpoint: PathLike
format: Literal["onnx"]
batch_size: int = 1
num_channels: int = 3
height: int = 224
width: int = 224
height: int | None = None
width: int | None = None
precision: OnnxPrecision = OnnxPrecision.F32_TRUE
simplify: bool = True
verify: bool = True
Expand Down
71 changes: 51 additions & 20 deletions tests/_commands/test_export_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# LICENSE file in the root directory of this source tree.
#

from __future__ import annotations

import sys
from pathlib import Path

Expand All @@ -19,18 +21,18 @@
from .. import helpers


@pytest.fixture(scope="module")
def dinov2_vits14_eomt_checkpoint(tmp_path_factory: pytest.TempPathFactory) -> Path:
tmp = tmp_path_factory.mktemp("tmp")
directory = tmp
def create_dinov2_vits14_eomt_test_checkpoint(
directory: Path, num_channels: int = 3
) -> Path:
out = directory / "out"
train_images = directory / "train_images"
train_masks = directory / "train_masks"
val_images = directory / "val_images"
val_masks = directory / "val_masks"
helpers.create_images(train_images)
mode = "RGBA" if num_channels == 4 else "RGB"
helpers.create_images(train_images, num_channels=num_channels, mode=mode)
helpers.create_masks(train_masks)
helpers.create_images(val_images)
helpers.create_images(val_images, num_channels=num_channels, mode=mode)
helpers.create_masks(val_masks)

lightly_train.train_semantic_segmentation(
Expand All @@ -50,6 +52,7 @@ def dinov2_vits14_eomt_checkpoint(tmp_path_factory: pytest.TempPathFactory) -> P
},
},
model="dinov2/vits14-eomt",
transform_args={"num_channels": num_channels},
# The operator 'aten::upsample_bicubic2d.out' raises a NotImplementedError
# on macOS with MPS backend.
accelerator="auto" if not sys.platform.startswith("darwin") else "cpu",
Expand All @@ -64,20 +67,38 @@ def dinov2_vits14_eomt_checkpoint(tmp_path_factory: pytest.TempPathFactory) -> P
return checkpoint_path


@pytest.fixture(scope="module")
def dinov2_vits14_eomt_checkpoint(tmp_path_factory: pytest.TempPathFactory) -> Path:
tmp = tmp_path_factory.mktemp("tmp")
return create_dinov2_vits14_eomt_test_checkpoint(directory=tmp)


@pytest.fixture(scope="module")
def dinov2_vits14_eomt_4_channels_checkpoint(
tmp_path_factory: pytest.TempPathFactory,
) -> Path:
tmp = tmp_path_factory.mktemp("tmp")
return create_dinov2_vits14_eomt_test_checkpoint(directory=tmp, num_channels=4)


onnx_export_testset = [
(1, 42, 154, OnnxPrecision.F32_TRUE),
(1, 42, 154, OnnxPrecision.F32_TRUE),
(2, 14, 14, OnnxPrecision.F32_TRUE),
(3, 140, 280, OnnxPrecision.F16_TRUE),
(4, 266, 28, OnnxPrecision.F16_TRUE),
(1, 3, 42, 154, OnnxPrecision.F32_TRUE),
(1, 4, 154, 42, OnnxPrecision.F32_TRUE),
(2, 3, 14, 14, OnnxPrecision.F32_TRUE),
(2, 4, None, None, OnnxPrecision.F32_TRUE),
(3, 3, 140, None, OnnxPrecision.F16_TRUE),
(4, 3, None, 28, OnnxPrecision.F16_TRUE),
# (4, 4, None, 28, OnnxPrecision.F16_TRUE), # TODO this test currently fails due to rounding deviations before argmax
]


@pytest.mark.skipif(
sys.platform.startswith("win"),
reason=("Fails on Windows because of potential memory issues"),
)
@pytest.mark.parametrize("batch_size,height,width,precision", onnx_export_testset)
@pytest.mark.parametrize(
"batch_size,num_channels,height,width,precision", onnx_export_testset
)
@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="Requires Python 3.9 or higher for image preprocessing.",
Expand All @@ -89,21 +110,31 @@ def dinov2_vits14_eomt_checkpoint(tmp_path_factory: pytest.TempPathFactory) -> P
@pytest.mark.skipif(not RequirementCache("onnxslim"), reason="onnxslim not installed")
def test_onnx_export(
batch_size: int,
height: int,
width: int,
num_channels: int,
height: int | None,
width: int | None,
precision: OnnxPrecision,
dinov2_vits14_eomt_checkpoint: Path,
dinov2_vits14_eomt_4_channels_checkpoint: Path,
tmp_path: Path,
) -> None:
import onnx
import onnxruntime as ort

# arrange
model = lightly_train.load_model_from_checkpoint(
dinov2_vits14_eomt_checkpoint, device="cpu"
)
checkpoint = {
3: dinov2_vits14_eomt_checkpoint,
4: dinov2_vits14_eomt_4_channels_checkpoint,
}[num_channels]
model = lightly_train.load_model_from_checkpoint(checkpoint, device="cpu")
if height is None:
height = model.image_size[0]
if width is None:
width = model.image_size[1]
onnx_path = tmp_path / "model.onnx"
validation_input = torch.randn(batch_size, 3, height, width, device="cpu")
validation_input = torch.randn(
batch_size, num_channels, height, width, device="cpu"
)
expected_outputs = model(validation_input)
expected_output_dtypes = [torch.int64, precision.torch()]
# We use torch.testing.assert_close to check if the model outputs the same as when we run the exported
Expand All @@ -114,7 +145,7 @@ def test_onnx_export(
# act
lightly_train.export_onnx(
out=onnx_path,
checkpoint=dinov2_vits14_eomt_checkpoint,
checkpoint=checkpoint,
height=height,
width=width,
precision=precision.value,
Expand Down Expand Up @@ -207,7 +238,7 @@ def test_onnx_export__width_not_patch_size_multiple_fails(
height = patch_size
width = patch_size - 1

# act
# act
with pytest.raises(
ValueError,
match=(
Expand Down