diff --git a/CHANGELOG.md b/CHANGELOG.md index a27e47e12..537fb3542 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -### Changed - +- Add support for training models on multi-channel images with `transform_args={"num_channels": 4}`. - Add support for using custom mask names for the inputs in semantic segmentation. -- Add `precision` flag to ONNX export task to specify if we export with float16 or float32 precision. + +### Changed ### Deprecated diff --git a/src/lightly_train/_commands/common_helpers.py b/src/lightly_train/_commands/common_helpers.py index e233270d2..e27713076 100644 --- a/src/lightly_train/_commands/common_helpers.py +++ b/src/lightly_train/_commands/common_helpers.py @@ -549,6 +549,7 @@ def get_dataset_mmap_filenames( def get_dataset( data: PathLike | Sequence[PathLike] | Dataset[DatasetItem], transform: Transform, + num_channels: int, mmap_filepath: Path | None, out_dir: Path, ) -> Dataset[DatasetItem]: @@ -583,6 +584,7 @@ def get_dataset( mmap_filepath=mmap_filepath, ), transform=transform, + num_channels=num_channels, mask_dir=Path(mask_dir) if mask_dir is not None else None, ) @@ -609,6 +611,7 @@ def get_dataset( mmap_filepath=mmap_filepath, ), transform=transform, + num_channels=num_channels, ) else: raise ValueError( diff --git a/src/lightly_train/_commands/embed.py b/src/lightly_train/_commands/embed.py index e7025c06f..55e98580b 100644 --- a/src/lightly_train/_commands/embed.py +++ b/src/lightly_train/_commands/embed.py @@ -136,6 +136,7 @@ def embed_from_config(config: EmbedConfig) -> None: dataset = common_helpers.get_dataset( data=config.data, transform=transform, + num_channels=len(checkpoint_instance.lightly_train.normalize_args.mean), mmap_filepath=mmap_filepath, out_dir=out_path, ) diff --git a/src/lightly_train/_commands/train.py b/src/lightly_train/_commands/train.py index fb29d7ea9..86202b77d 100644 --- a/src/lightly_train/_commands/train.py +++ b/src/lightly_train/_commands/train.py @@ -30,6 +30,7 @@ from lightly_train._commands.common_helpers import ModelFormat from lightly_train._configs import omegaconf_utils, validate from lightly_train._configs.config import PydanticConfig +from lightly_train._configs.validate import no_auto from lightly_train._loggers import logger_helpers from lightly_train._loggers.logger_args import LoggerArgs from lightly_train._methods import method_helpers @@ -295,6 +296,7 @@ def train_from_config(config: TrainConfig) -> None: dataset = common_helpers.get_dataset( data=config.data, transform=transform_instance, + num_channels=no_auto(transform_instance.transform_args.num_channels), mmap_filepath=mmap_filepath, out_dir=out_dir, ) @@ -310,7 +312,9 @@ def train_from_config(config: TrainConfig) -> None: epochs=config.epochs, ) wrapped_model = package_helpers.get_wrapped_model( - model=config.model, model_args=config.model_args + model=config.model, + model_args=config.model_args, + num_input_channels=no_auto(transform_instance.transform_args.num_channels), ) embedding_model = train_helpers.get_embedding_model( wrapped_model=wrapped_model, embed_dim=config.embed_dim @@ -398,6 +402,7 @@ def train_from_config(config: TrainConfig) -> None: optimizer_args=config.optim_args, embedding_model=embedding_model, global_batch_size=config.batch_size, + num_input_channels=no_auto(transform_instance.transform_args.num_channels), ) train_helpers.load_checkpoint( checkpoint=config.checkpoint, diff --git a/src/lightly_train/_commands/train_helpers.py b/src/lightly_train/_commands/train_helpers.py index 4af9ce2d6..e8a6eabd6 100644 --- a/src/lightly_train/_commands/train_helpers.py +++ b/src/lightly_train/_commands/train_helpers.py @@ -50,19 +50,23 @@ def get_transform_args( ) -> MethodTransformArgs: logger.debug(f"Getting transform args for method '{method}'.") logger.debug(f"Using additional transform arguments {transform_args}.") - if isinstance(transform_args, MethodTransformArgs): - return transform_args - - method_cls = method_helpers.get_method_cls(method) - transform_cls = method_cls.transform_cls() - transform_args_cls = transform_cls.transform_args_cls() - - if transform_args is None: - # We need to typeignore here because a MethodTransformArgs might not have - # defaults for all fields, while its children do. - return transform_args_cls() # type: ignore[call-arg] + if not isinstance(transform_args, MethodTransformArgs): + method_cls = method_helpers.get_method_cls(method) + transform_cls = method_cls.transform_cls() + transform_args_cls = transform_cls.transform_args_cls() + + if transform_args is None: + # We need to typeignore here because a MethodTransformArgs might not have + # defaults for all fields, while its children do. + transform_args = transform_args_cls() # type: ignore[call-arg] + else: + transform_args = validate.pydantic_model_validate( + transform_args_cls, transform_args + ) - return validate.pydantic_model_validate(transform_args_cls, transform_args) + transform_args.resolve_auto() + transform_args.resolve_incompatible() + return transform_args def get_transform( @@ -326,6 +330,7 @@ def get_method( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ) -> Method: logger.debug(f"Getting method for '{method_cls.__name__}'") return method_cls( @@ -333,6 +338,7 @@ def get_method( optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) diff --git a/src/lightly_train/_commands/train_task_helpers.py b/src/lightly_train/_commands/train_task_helpers.py index 5d0541b75..b970372d3 100644 --- a/src/lightly_train/_commands/train_task_helpers.py +++ b/src/lightly_train/_commands/train_task_helpers.py @@ -54,6 +54,9 @@ TrainModelArgs, ) from lightly_train._train_task_state import TrainTaskState +from lightly_train._transforms.semantic_segmentation_transform import ( + SemanticSegmentationTransform, +) from lightly_train._transforms.task_transform import ( TaskTransform, TaskTransformArgs, @@ -229,10 +232,16 @@ def get_transform_args( train_transform_args_cls, transform_args ) train_transform_args.resolve_auto() + train_transform_args.resolve_incompatible() # Take defaults from train transform. val_args_dict = train_transform_args.model_dump( - include={"image_size": True, "normalize": True, "ignore_index": True} + include={ + "image_size": True, + "normalize": True, + "ignore_index": True, + "num_channels": True, + } ) # Overwrite with user provided val args. val_args_dict.update(val_args) @@ -240,6 +249,7 @@ def get_transform_args( val_transform_args_cls, val_args_dict ) val_transform_args.resolve_auto() + val_transform_args.resolve_incompatible() logger.debug( f"Resolved train transform args {pretty_format_args(train_transform_args.model_dump())}" @@ -408,6 +418,8 @@ def get_dataset( image_info = dataset_args.list_image_info() dataset_cls = dataset_args.get_dataset_cls() + # TODO(Guarin, 08/25): Relax this when we add object detection. + assert isinstance(transform, SemanticSegmentationTransform) return dataset_cls( dataset_args=dataset_args, image_info=get_dataset_mmap_file( diff --git a/src/lightly_train/_data/image_dataset.py b/src/lightly_train/_data/image_dataset.py index ea664c039..6e2c23a99 100644 --- a/src/lightly_train/_data/image_dataset.py +++ b/src/lightly_train/_data/image_dataset.py @@ -27,20 +27,31 @@ def __init__( image_dir: Path | None, image_filenames: Sequence[ImageFilename], transform: Transform, + num_channels: int, mask_dir: Path | None = None, ): self.image_dir = image_dir self.image_filenames = image_filenames self.mask_dir = mask_dir self.transform = transform + self.num_channels = num_channels - try: - self.image_mode = ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value) - except ValueError: + image_mode = ( + None + if Env.LIGHTLY_TRAIN_IMAGE_MODE.value is None + else ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value) + ) + if image_mode is None: + image_mode = ( + ImageMode.RGB if self.num_channels == 3 else ImageMode.UNCHANGED + ) + + if image_mode not in (ImageMode.RGB, ImageMode.UNCHANGED): raise ValueError( - f'Invalid image mode: {Env.LIGHTLY_TRAIN_IMAGE_MODE.name}="{Env.LIGHTLY_TRAIN_IMAGE_MODE.value}". ' - "Supported modes are 'RGB' and 'UNCHANGED'." + f"Invalid image mode: '{image_mode}'. " + f"Supported modes are '{[ImageMode.RGB.value, ImageMode.UNCHANGED.value]}'." ) + self.image_mode = image_mode def __getitem__(self, idx: int) -> DatasetItem: filename = self.image_filenames[idx] diff --git a/src/lightly_train/_data/mask_semantic_segmentation_dataset.py b/src/lightly_train/_data/mask_semantic_segmentation_dataset.py index d1525e638..3ee9ace58 100644 --- a/src/lightly_train/_data/mask_semantic_segmentation_dataset.py +++ b/src/lightly_train/_data/mask_semantic_segmentation_dataset.py @@ -21,7 +21,10 @@ from lightly_train._data.file_helpers import ImageMode from lightly_train._data.task_data_args import TaskDataArgs from lightly_train._env import Env -from lightly_train._transforms.task_transform import TaskTransform +from lightly_train._transforms.semantic_segmentation_transform import ( + SemanticSegmentationTransform, + SemanticSegmentationTransformArgs, +) from lightly_train.types import ( BinaryMasksDict, MaskSemanticSegmentationDatasetItem, @@ -39,7 +42,7 @@ def __init__( self, dataset_args: MaskSemanticSegmentationDatasetArgs, image_info: Sequence[dict[str, str]], - transform: TaskTransform, + transform: SemanticSegmentationTransform, ): self.args = dataset_args self.filepaths = image_info @@ -50,20 +53,27 @@ def __init__( self.class_mapping = self.get_class_mapping() self.valid_classes = torch.tensor(list(self.class_mapping.keys())) - image_mode = Env.LIGHTLY_TRAIN_IMAGE_MODE.value - if image_mode not in ("RGB", "UNCHANGED"): + transform_args = transform.transform_args + assert isinstance(transform_args, SemanticSegmentationTransformArgs) + + image_mode = ( + None + if Env.LIGHTLY_TRAIN_IMAGE_MODE.value is None + else ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value) + ) + if image_mode is None: + image_mode = ( + ImageMode.RGB + if transform_args.num_channels == 3 + else ImageMode.UNCHANGED + ) + + if image_mode not in (ImageMode.RGB, ImageMode.UNCHANGED): raise ValueError( - f'Invalid image mode: {Env.LIGHTLY_TRAIN_IMAGE_MODE.name}="{image_mode}". ' - "Supported modes are 'RGB' and 'UNCHANGED'." + f"Invalid image mode: '{image_mode}'. " + f"Supported modes are '{[ImageMode.RGB.value, ImageMode.UNCHANGED.value]}'." ) - # Convert string to enum value - if image_mode == "RGB": - self.image_mode = ImageMode.RGB - elif image_mode == "UNCHANGED": - self.image_mode = ImageMode.UNCHANGED - else: - # This should not happen due to the check above, but added for type safety - raise ValueError(f"Unexpected image mode: {image_mode}") + self.image_mode = image_mode def is_mask_valid(self, mask: Tensor) -> bool: # Check if at least one value in the mask is in the valid classes. diff --git a/src/lightly_train/_embedding/embedding_transform.py b/src/lightly_train/_embedding/embedding_transform.py index 9cef5453f..1906e05f7 100644 --- a/src/lightly_train/_embedding/embedding_transform.py +++ b/src/lightly_train/_embedding/embedding_transform.py @@ -21,8 +21,8 @@ class EmbeddingTransform: def __init__( self, image_size: int | tuple[int, int], - mean: tuple[float, float, float], - std: tuple[float, float, float], + mean: tuple[float, ...], + std: tuple[float, ...], ): if isinstance(image_size, int): image_size = (image_size, image_size) diff --git a/src/lightly_train/_env.py b/src/lightly_train/_env.py index d7b711963..f6e3162ba 100644 --- a/src/lightly_train/_env.py +++ b/src/lightly_train/_env.py @@ -101,9 +101,9 @@ class Env: ) # Mode in which images are loaded. This can be "RGB" to load images in RGB or # "UNCHANGED" to load images in their original format without any conversion. - LIGHTLY_TRAIN_IMAGE_MODE: EnvVar[str] = EnvVar( + LIGHTLY_TRAIN_IMAGE_MODE: EnvVar[str | None] = EnvVar( name="LIGHTLY_TRAIN_IMAGE_MODE", - _default="RGB", + _default=None, _type=str, ) LIGHTLY_TRAIN_MASK_DIR: EnvVar[Path | None] = EnvVar( diff --git a/src/lightly_train/_methods/densecl/densecl.py b/src/lightly_train/_methods/densecl/densecl.py index be5aed7e6..b390a2143 100644 --- a/src/lightly_train/_methods/densecl/densecl.py +++ b/src/lightly_train/_methods/densecl/densecl.py @@ -156,12 +156,14 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ): super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) self.method_args = method_args self.query_encoder = DenseCLEncoder( diff --git a/src/lightly_train/_methods/densecl/densecl_transform.py b/src/lightly_train/_methods/densecl/densecl_transform.py index acbf27018..2adf39a33 100644 --- a/src/lightly_train/_methods/densecl/densecl_transform.py +++ b/src/lightly_train/_methods/densecl/densecl_transform.py @@ -8,6 +8,7 @@ from __future__ import annotations from pydantic import Field +from typing_extensions import Literal from lightly_train._transforms.transform import ( ChannelDropArgs, @@ -55,6 +56,7 @@ class DenseCLGaussianBlurArgs(GaussianBlurArgs): class DenseCLTransformArgs(MethodTransformArgs): image_size: tuple[int, int] = Field(default=(224, 224), strict=False) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" random_resize: DenseCLRandomResizeArgs | None = Field( default_factory=DenseCLRandomResizeArgs ) diff --git a/src/lightly_train/_methods/detcon/detcon.py b/src/lightly_train/_methods/detcon/detcon.py index 520ccb4ff..82885285e 100644 --- a/src/lightly_train/_methods/detcon/detcon.py +++ b/src/lightly_train/_methods/detcon/detcon.py @@ -151,12 +151,14 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ) -> None: super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) self.method_args = method_args self.embedding_model = embedding_model @@ -268,12 +270,14 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ) -> None: super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) self.method_args = method_args diff --git a/src/lightly_train/_methods/detcon/detcon_transform.py b/src/lightly_train/_methods/detcon/detcon_transform.py index f3c40de6b..d290b72c7 100644 --- a/src/lightly_train/_methods/detcon/detcon_transform.py +++ b/src/lightly_train/_methods/detcon/detcon_transform.py @@ -7,6 +7,8 @@ # from __future__ import annotations +from typing import Literal + from pydantic import Field from lightly_train._configs.config import PydanticConfig @@ -58,6 +60,7 @@ class DetConSView1TransformArgs(PydanticConfig): class DetConSTransformArgs(MethodTransformArgs): image_size: tuple[int, int] = Field(default=(224, 224), strict=False) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" random_resize: RandomResizeArgs | None = Field(default_factory=RandomResizeArgs) random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs) random_rotation: RandomRotationArgs | None = None @@ -109,6 +112,7 @@ class DetConBView1TransformArgs(PydanticConfig): class DetConBTransformArgs(MethodTransformArgs): image_size: tuple[int, int] = Field(default=(224, 224), strict=False) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" random_resize: RandomResizeArgs | None = Field(default_factory=RandomResizeArgs) random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs) random_rotation: RandomRotationArgs | None = None diff --git a/src/lightly_train/_methods/dino/dino.py b/src/lightly_train/_methods/dino/dino.py index 6f4a4538a..15e813983 100644 --- a/src/lightly_train/_methods/dino/dino.py +++ b/src/lightly_train/_methods/dino/dino.py @@ -167,12 +167,14 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ): super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) self.method_args = method_args self.teacher_embedding_model = embedding_model diff --git a/src/lightly_train/_methods/dino/dino_transform.py b/src/lightly_train/_methods/dino/dino_transform.py index 8d5a69ae8..0302d34ba 100644 --- a/src/lightly_train/_methods/dino/dino_transform.py +++ b/src/lightly_train/_methods/dino/dino_transform.py @@ -7,6 +7,8 @@ # from __future__ import annotations +from typing import Literal + from pydantic import Field from lightly_train._configs.config import PydanticConfig @@ -99,6 +101,7 @@ class DINOTransformArgs(MethodTransformArgs): # https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings image_size: tuple[int, int] = Field(default=(224, 224), strict=False) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" random_resize: DINORandomResizeArgs | None = Field( default_factory=DINORandomResizeArgs ) diff --git a/src/lightly_train/_methods/dinov2/dinov2.py b/src/lightly_train/_methods/dinov2/dinov2.py index 807ac95f2..8bc29fddd 100644 --- a/src/lightly_train/_methods/dinov2/dinov2.py +++ b/src/lightly_train/_methods/dinov2/dinov2.py @@ -182,12 +182,14 @@ def __init__( optimizer_args: DINOv2AdamWViTArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ): super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) self.method_args = method_args diff --git a/src/lightly_train/_methods/distillation/distillation.py b/src/lightly_train/_methods/distillation/distillation.py index f9778767f..add7957c7 100644 --- a/src/lightly_train/_methods/distillation/distillation.py +++ b/src/lightly_train/_methods/distillation/distillation.py @@ -44,8 +44,14 @@ logger = logging.getLogger(__name__) -def get_teacher(teacher_name: str, teacher_weights: str | Path | None = None) -> Module: - wrapped_model = package_helpers.get_wrapped_model(model=teacher_name) +def get_teacher( + teacher_name: str, + num_input_channels: int, + teacher_weights: str | Path | None = None, +) -> Module: + wrapped_model = package_helpers.get_wrapped_model( + model=teacher_name, num_input_channels=num_input_channels + ) assert isinstance(wrapped_model, (DINOv2ViTModelWrapper, DINOv3ViTModelWrapper)) wrapped_model.make_teacher() teacher_embedding_model = wrapped_model.get_model() @@ -135,16 +141,20 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ): super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) # Get the teacher model. self.teacher_embedding_model = get_teacher( - method_args.teacher, method_args.teacher_weights + method_args.teacher, + num_input_channels=num_input_channels, + teacher_weights=method_args.teacher_weights, ) # Store the student model. diff --git a/src/lightly_train/_methods/distillation/distillation_transform.py b/src/lightly_train/_methods/distillation/distillation_transform.py index c277e02d3..43296c10c 100644 --- a/src/lightly_train/_methods/distillation/distillation_transform.py +++ b/src/lightly_train/_methods/distillation/distillation_transform.py @@ -8,6 +8,7 @@ from __future__ import annotations from pydantic import Field +from typing_extensions import Literal from lightly_train._transforms.transform import ( ChannelDropArgs, @@ -54,6 +55,7 @@ class DistillationGaussianBlurArgs(GaussianBlurArgs): class DistillationTransformArgs(MethodTransformArgs): image_size: tuple[int, int] = Field(default=(224, 224), strict=False) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" random_resize: DistillationRandomResizeArgs | None = Field( default_factory=DistillationRandomResizeArgs ) diff --git a/src/lightly_train/_methods/distillationv2/distillationv2.py b/src/lightly_train/_methods/distillationv2/distillationv2.py index b5e794e48..ec1cbb3b6 100644 --- a/src/lightly_train/_methods/distillationv2/distillationv2.py +++ b/src/lightly_train/_methods/distillationv2/distillationv2.py @@ -43,6 +43,7 @@ def get_teacher( teacher_name: str, + num_input_channels: int, teacher_weights: str | Path | None = None, method_args: DistillationV2Args | None = None, ) -> Module: @@ -51,7 +52,7 @@ def get_teacher( model_args["weights"] = method_args.teacher_url wrapped_model = package_helpers.get_wrapped_model( - model=teacher_name, model_args=model_args + model=teacher_name, num_input_channels=num_input_channels, model_args=model_args ) assert isinstance(wrapped_model, (DINOv2ViTModelWrapper, DINOv3ViTModelWrapper)) wrapped_model.make_teacher() @@ -155,16 +156,21 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ): super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) # Get the teacher model. self.teacher_embedding_model = get_teacher( - method_args.teacher, method_args.teacher_weights, method_args + teacher_name=method_args.teacher, + num_input_channels=num_input_channels, + teacher_weights=method_args.teacher_weights, + method_args=method_args, ) self.teacher_embedding_dim = ( method_args.n_teacher_blocks * self.teacher_embedding_model.embed_dim diff --git a/src/lightly_train/_methods/distillationv2/distillationv2_transform.py b/src/lightly_train/_methods/distillationv2/distillationv2_transform.py index ba01e00ba..449ad79e2 100644 --- a/src/lightly_train/_methods/distillationv2/distillationv2_transform.py +++ b/src/lightly_train/_methods/distillationv2/distillationv2_transform.py @@ -9,6 +9,8 @@ # Note: This file is identical (up to renaming) to src/lightly_train/_methods/distillation/distillation_transform.py from __future__ import annotations +from typing import Literal + from pydantic import Field from lightly_train._transforms.transform import ( @@ -56,6 +58,7 @@ class DistillationV2GaussianBlurArgs(GaussianBlurArgs): class DistillationV2TransformArgs(MethodTransformArgs): image_size: tuple[int, int] = Field(default=(224, 224), strict=False) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" random_resize: DistillationV2RandomResizeArgs | None = Field( default_factory=DistillationV2RandomResizeArgs ) diff --git a/src/lightly_train/_methods/method.py b/src/lightly_train/_methods/method.py index 238250c0a..ca3423f70 100644 --- a/src/lightly_train/_methods/method.py +++ b/src/lightly_train/_methods/method.py @@ -54,6 +54,7 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ): super().__init__() self.global_batch_size = global_batch_size diff --git a/src/lightly_train/_methods/simclr/simclr.py b/src/lightly_train/_methods/simclr/simclr.py index c9dc60c28..74743d81d 100644 --- a/src/lightly_train/_methods/simclr/simclr.py +++ b/src/lightly_train/_methods/simclr/simclr.py @@ -52,12 +52,14 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int, ): super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) self.method_args = method_args self.embedding_model = embedding_model diff --git a/src/lightly_train/_methods/simclr/simclr_transform.py b/src/lightly_train/_methods/simclr/simclr_transform.py index 399721ba7..479c754ec 100644 --- a/src/lightly_train/_methods/simclr/simclr_transform.py +++ b/src/lightly_train/_methods/simclr/simclr_transform.py @@ -8,6 +8,7 @@ from __future__ import annotations from pydantic import Field +from typing_extensions import Literal from lightly_train._transforms.transform import ( ChannelDropArgs, @@ -50,6 +51,7 @@ class SimCLRGaussianBlurArgs(GaussianBlurArgs): class SimCLRTransformArgs(MethodTransformArgs): image_size: tuple[int, int] = Field(default=(224, 224), strict=False) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" random_resize: RandomResizeArgs | None = Field(default_factory=RandomResizeArgs) random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs) random_rotation: RandomRotationArgs | None = None diff --git a/src/lightly_train/_models/_model_helpers.py b/src/lightly_train/_models/_model_helpers.py new file mode 100644 index 000000000..14c47f023 --- /dev/null +++ b/src/lightly_train/_models/_model_helpers.py @@ -0,0 +1,57 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +from __future__ import annotations + +import logging +from typing import Any + +import torch +from torch.nn import Module + +logger = logging.getLogger(__name__) + + +def patch_embed_adjust_input_channels_hook( + module: Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, + **kwargs: Any, +) -> None: + """Hook to adjust the number of channels in the state dict to the number of + channels in the module. + """ + + proj_weight_key = f"{prefix}proj.weight" + proj_weight = state_dict.get(proj_weight_key) + if proj_weight is not None: + weights_in_chans = proj_weight.shape[1] + if weights_in_chans > module.in_chans: + # Drop last channels + logger.info( + f"Loading pretrained weights with {weights_in_chans} input channels, " + f"but model has {module.in_chans} input channels. Keeping only the " + f"first {module.in_chans} channels of the pretrained weights." + ) + proj_weight = proj_weight[:, : module.in_chans, :, :] + elif weights_in_chans < module.in_chans: + # Repeat channels to initialize extra channels + logger.info( + f"Loading pretrained weights with {weights_in_chans} input channels, " + f"but model has {module.in_chans} input channels. Repeating the " + "channels of the pretrained weights to initialize the extra " + "channels." + ) + repeat_times = module.in_chans // weights_in_chans + remainder = module.in_chans % weights_in_chans + proj_weight = proj_weight.repeat(1, repeat_times, 1, 1) + if remainder > 0: + proj_weight = torch.cat( + [proj_weight, proj_weight[:, :remainder, :, :]], dim=1 + ) + state_dict[proj_weight_key] = proj_weight diff --git a/src/lightly_train/_models/dinov2_vit/dinov2_vit_package.py b/src/lightly_train/_models/dinov2_vit/dinov2_vit_package.py index 1ba91f2ac..497cf1e0f 100644 --- a/src/lightly_train/_models/dinov2_vit/dinov2_vit_package.py +++ b/src/lightly_train/_models/dinov2_vit/dinov2_vit_package.py @@ -77,7 +77,10 @@ def parse_model_name(cls, model_name: str) -> str: @classmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> DinoVisionTransformer: """ Get a DINOv2 ViT model by name. Here the student version is build. @@ -114,6 +117,7 @@ def get_model( interpolate_antialias=cfg.student.interpolate_antialias, drop_path_rate=cfg.student.drop_path_rate, drop_path_uniform=cfg.student.drop_path_uniform, + in_chans=num_input_channels, ) kwargs.update(model_args or {}) diff --git a/src/lightly_train/_models/dinov2_vit/dinov2_vit_src/layers/patch_embed.py b/src/lightly_train/_models/dinov2_vit/dinov2_vit_src/layers/patch_embed.py index 7b41f3a9d..7d8ea3a53 100644 --- a/src/lightly_train/_models/dinov2_vit/dinov2_vit_src/layers/patch_embed.py +++ b/src/lightly_train/_models/dinov2_vit/dinov2_vit_src/layers/patch_embed.py @@ -9,6 +9,10 @@ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py +# Modifications Copyright 2025 Lightly AG: +# - Modified load_state_dict to handle different number of input channels + + import math from typing import Callable, Optional, Tuple, Union @@ -16,6 +20,8 @@ import torch.nn.functional as F from torch import Tensor +from lightly_train._models import _model_helpers + def make_2tuple(x): if isinstance(x, tuple): @@ -71,6 +77,16 @@ def __init__( ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + if hasattr(self, "register_load_state_dict_pre_hook"): + self.register_load_state_dict_pre_hook( + _model_helpers.patch_embed_adjust_input_channels_hook + ) + else: + # Backwards compatibility for PyTorch <= 2.4 + self._register_load_state_dict_pre_hook( + _model_helpers.patch_embed_adjust_input_channels_hook, with_module=True + ) + def forward(self, x: Tensor) -> Tuple[Tensor, int, int]: _, _, H, W = x.shape patch_H, patch_W = self.patch_size diff --git a/src/lightly_train/_models/dinov3/dinov3_package.py b/src/lightly_train/_models/dinov3/dinov3_package.py index 19ba90fe5..9bb10cdc9 100644 --- a/src/lightly_train/_models/dinov3/dinov3_package.py +++ b/src/lightly_train/_models/dinov3/dinov3_package.py @@ -74,13 +74,18 @@ def parse_model_name(cls, model_name: str) -> str: @classmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> DinoVisionTransformer: """ Get a DINOv3 ViT model by name. Here the student version is build. """ - model_args = {} if model_args is None else model_args - model = MODEL_NAME_TO_GETTER[model_name](**model_args) + args: dict[str, Any] = {"in_chans": num_input_channels} + if model_args is not None: + args.update(model_args) + model = MODEL_NAME_TO_GETTER[model_name](**args) assert isinstance(model, DinoVisionTransformer) return model diff --git a/src/lightly_train/_models/dinov3/dinov3_src/hub/backbones.py b/src/lightly_train/_models/dinov3/dinov3_src/hub/backbones.py index 6daeac3de..3d6906e9d 100644 --- a/src/lightly_train/_models/dinov3/dinov3_src/hub/backbones.py +++ b/src/lightly_train/_models/dinov3/dinov3_src/hub/backbones.py @@ -217,6 +217,7 @@ def dinov3_vits16( pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, check_hash: bool = False, + in_chans: int = 3, **kwargs, ): if "hash" not in kwargs: @@ -225,7 +226,7 @@ def dinov3_vits16( return _make_dinov3_vit( img_size=224, patch_size=16, - in_chans=3, + in_chans=in_chans, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, @@ -256,6 +257,7 @@ def dinov3_vits16plus( pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, check_hash: bool = False, + in_chans: int = 3, **kwargs, ): if "hash" not in kwargs: @@ -264,7 +266,7 @@ def dinov3_vits16plus( return _make_dinov3_vit( img_size=224, patch_size=16, - in_chans=3, + in_chans=in_chans, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, @@ -295,6 +297,7 @@ def dinov3_vitb16( pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, check_hash: bool = False, + in_chans: int = 3, **kwargs, ): if "hash" not in kwargs: @@ -303,7 +306,7 @@ def dinov3_vitb16( return _make_dinov3_vit( img_size=224, patch_size=16, - in_chans=3, + in_chans=in_chans, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, @@ -334,6 +337,7 @@ def dinov3_vitl16( pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, check_hash: bool = False, + in_chans: int = 3, **kwargs, ): untie_global_and_local_cls_norm = False @@ -360,7 +364,7 @@ def dinov3_vitl16( return _make_dinov3_vit( img_size=224, patch_size=16, - in_chans=3, + in_chans=in_chans, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, @@ -392,6 +396,7 @@ def dinov3_vitl16plus( pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, check_hash: bool = False, + in_chans: int = 3, **kwargs, ): if "hash" not in kwargs: @@ -400,7 +405,7 @@ def dinov3_vitl16plus( return _make_dinov3_vit( img_size=224, patch_size=16, - in_chans=3, + in_chans=in_chans, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, @@ -431,6 +436,7 @@ def dinov3_vith16plus( pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, check_hash: bool = False, + in_chans: int = 3, **kwargs, ): if "hash" not in kwargs: @@ -439,7 +445,7 @@ def dinov3_vith16plus( return _make_dinov3_vit( img_size=224, patch_size=16, - in_chans=3, + in_chans=in_chans, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, @@ -470,6 +476,7 @@ def dinov3_vit7b16( pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, check_hash: bool = False, + in_chans: int = 3, **kwargs, ): if weights == Weights.LVD1689M: @@ -483,7 +490,7 @@ def dinov3_vit7b16( return _make_dinov3_vit( img_size=224, patch_size=16, - in_chans=3, + in_chans=in_chans, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, @@ -514,6 +521,7 @@ def dinov3_convnext_tiny( *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, + in_chans: int = 3, **kwargs, ): _hash_convnext = "21b726bb" @@ -525,7 +533,7 @@ def dinov3_convnext_tiny( size_dict = convnext_sizes["tiny"] model = _make_dinov3_convnext( - in_chans=3, + in_chans=in_chans, depths=size_dict["depths"], dims=size_dict["dims"], compact_arch_name="convnext_tiny", @@ -544,6 +552,7 @@ def dinov3_convnext_small( *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, + in_chans: int = 3, **kwargs, ): _hash_convnext = "296db49d" @@ -555,7 +564,7 @@ def dinov3_convnext_small( size_dict = convnext_sizes["small"] model = _make_dinov3_convnext( - in_chans=3, + in_chans=in_chans, depths=size_dict["depths"], dims=size_dict["dims"], compact_arch_name="convnext_small", @@ -574,6 +583,7 @@ def dinov3_convnext_base( *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, + in_chans: int = 3, **kwargs, ): _hash_convnext = "801f2ba9" @@ -585,7 +595,7 @@ def dinov3_convnext_base( size_dict = convnext_sizes["base"] model = _make_dinov3_convnext( - in_chans=3, + in_chans=in_chans, depths=size_dict["depths"], dims=size_dict["dims"], compact_arch_name="convnext_base", @@ -604,6 +614,7 @@ def dinov3_convnext_large( *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD1689M, + in_chans: int = 3, **kwargs, ): _hash_convnext = "61fa432d" @@ -615,7 +626,7 @@ def dinov3_convnext_large( size_dict = convnext_sizes["large"] model = _make_dinov3_convnext( - in_chans=3, + in_chans=in_chans, depths=size_dict["depths"], dims=size_dict["dims"], compact_arch_name="convnext_large", diff --git a/src/lightly_train/_models/dinov3/dinov3_src/layers/patch_embed.py b/src/lightly_train/_models/dinov3/dinov3_src/layers/patch_embed.py index aa6223dfe..8562f4676 100644 --- a/src/lightly_train/_models/dinov3/dinov3_src/layers/patch_embed.py +++ b/src/lightly_train/_models/dinov3/dinov3_src/layers/patch_embed.py @@ -4,6 +4,9 @@ # This software may be used and distributed in accordance with # the terms of the DINOv3 License Agreement.# +# Modifications Copyright 2025 Lightly AG: +# - Modified load_state_dict to handle different number of input channels + from __future__ import annotations import math @@ -11,6 +14,8 @@ from torch import Tensor, nn +from lightly_train._models import _model_helpers + def make_2tuple(x): if isinstance(x, tuple): @@ -66,6 +71,16 @@ def __init__( ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + if hasattr(self, "register_load_state_dict_pre_hook"): + self.register_load_state_dict_pre_hook( + _model_helpers.patch_embed_adjust_input_channels_hook + ) + else: + # Backwards compatibility for PyTorch <= 2.4 + self._register_load_state_dict_pre_hook( + _model_helpers.patch_embed_adjust_input_channels_hook, with_module=True + ) + def forward(self, x: Tensor) -> Tensor: _, _, H, W = x.shape # patch_H, patch_W = self.patch_size diff --git a/src/lightly_train/_models/package.py b/src/lightly_train/_models/package.py index d62b75cdc..0ea9e4b10 100644 --- a/src/lightly_train/_models/package.py +++ b/src/lightly_train/_models/package.py @@ -50,7 +50,10 @@ def list_model_names(cls) -> list[str]: @classmethod @abstractmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> PackageModel: """Get the underlying model of the package by its name.""" ... diff --git a/src/lightly_train/_models/package_helpers.py b/src/lightly_train/_models/package_helpers.py index a031307c6..1471d80d2 100644 --- a/src/lightly_train/_models/package_helpers.py +++ b/src/lightly_train/_models/package_helpers.py @@ -70,7 +70,9 @@ def list_model_names() -> list[str]: def get_wrapped_model( - model: str | Module | ModelWrapper, model_args: dict[str, Any] | None = None + model: str | Module | ModelWrapper, + num_input_channels: int, + model_args: dict[str, Any] | None = None, ) -> ModelWrapper: """Returns a wrapped model instance given a model name or instance.""" if isinstance(model, ModelWrapper): @@ -80,7 +82,9 @@ def get_wrapped_model( if isinstance(model, str): package_name, model_name = parse_model_name(model) package = get_package(package_name) - model = package.get_model(model_name, model_args=model_args) + model = package.get_model( + model_name, num_input_channels=num_input_channels, model_args=model_args + ) else: package = get_package_from_model( model, include_custom=False, fallback_custom=False diff --git a/src/lightly_train/_models/rfdetr/rfdetr_package.py b/src/lightly_train/_models/rfdetr/rfdetr_package.py index f081cb6e7..3fc5e3141 100644 --- a/src/lightly_train/_models/rfdetr/rfdetr_package.py +++ b/src/lightly_train/_models/rfdetr/rfdetr_package.py @@ -52,7 +52,10 @@ def is_supported_model(cls, model: RFDETR | ModelWrapper | Any) -> bool: @classmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> RFDETR: try: from rfdetr import RFDETRBase, RFDETRLarge @@ -61,6 +64,11 @@ def get_model( raise ValueError( f"Cannot create model '{model_name}' because rfdetr is not installed." ) + if num_input_channels != 3: + raise ValueError( + f"RFDETR models only support 3 input channels, but got " + f"{num_input_channels}." + ) args = {} if model_args is None else model_args.copy() # Remove these arguments so that get_model() only returns the full model diff --git a/src/lightly_train/_models/super_gradients/super_gradients_package.py b/src/lightly_train/_models/super_gradients/super_gradients_package.py index 07740de62..f6c0da9e5 100644 --- a/src/lightly_train/_models/super_gradients/super_gradients_package.py +++ b/src/lightly_train/_models/super_gradients/super_gradients_package.py @@ -70,7 +70,10 @@ def is_supported_model_cls(cls, model_cls: type[Module]) -> bool: @classmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> Module: try: from super_gradients.training import models @@ -79,6 +82,11 @@ def get_model( f"Cannot create model '{model_name}' because '{cls.name}' is not " "installed." ) + if num_input_channels != 3: + raise ValueError( + f"SuperGradients models only support 3 input channels, but got " + f"{num_input_channels}." + ) args = dict(num_classes=10) if model_args is not None: args.update(model_args) diff --git a/src/lightly_train/_models/timm/timm_package.py b/src/lightly_train/_models/timm/timm_package.py index 17da5870a..49276dcc2 100644 --- a/src/lightly_train/_models/timm/timm_package.py +++ b/src/lightly_train/_models/timm/timm_package.py @@ -45,7 +45,10 @@ def is_supported_model(cls, model: Module | ModelWrapper | Any) -> bool: @classmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> Module: try: import timm @@ -53,7 +56,7 @@ def get_model( raise ValueError( f"Cannot create model '{model_name}' because timm is not installed." ) - args = dict(pretrained=False) + args = dict(pretrained=False, in_chans=num_input_channels) # vit and eva models have dynamic_img_size defaulting to False, which would not allow inputs with varying image sizes, e.g., for DINO if ( model_name.startswith("vit") diff --git a/src/lightly_train/_models/torchvision/torchvision_package.py b/src/lightly_train/_models/torchvision/torchvision_package.py index 8e7e1697d..6a11c86c2 100644 --- a/src/lightly_train/_models/torchvision/torchvision_package.py +++ b/src/lightly_train/_models/torchvision/torchvision_package.py @@ -58,8 +58,16 @@ def is_supported_model(cls, model: Module | ModelWrapper | Any) -> bool: @classmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> Module: + if num_input_channels != 3: + raise ValueError( + f"Torchvision models only support 3 input channels, but got " + f"{num_input_channels}." + ) args = dict() if model_args is not None: args.update(model_args) diff --git a/src/lightly_train/_models/ultralytics/ultralytics_package.py b/src/lightly_train/_models/ultralytics/ultralytics_package.py index 506bd4da4..d6dd00025 100644 --- a/src/lightly_train/_models/ultralytics/ultralytics_package.py +++ b/src/lightly_train/_models/ultralytics/ultralytics_package.py @@ -84,7 +84,10 @@ def is_supported_model(cls, model: Module | ModelWrapper) -> bool: @classmethod def get_model( - cls, model_name: str, model_args: dict[str, Any] | None = None + cls, + model_name: str, + num_input_channels: int = 3, + model_args: dict[str, Any] | None = None, ) -> Module: try: from ultralytics import YOLO @@ -93,6 +96,11 @@ def get_model( f"Cannot create model '{model_name}' because '{cls.name}' is not " "installed." ) + if num_input_channels != 3: + raise ValueError( + f"Ultralytics models only support 3 input channels, but got " + f"{num_input_channels}." + ) args = {} if model_args is None else model_args model: Module = YOLO(model=model_name, **args) return model diff --git a/src/lightly_train/_plot.py b/src/lightly_train/_plot.py index 6ddfb7650..459a53be9 100644 --- a/src/lightly_train/_plot.py +++ b/src/lightly_train/_plot.py @@ -60,7 +60,7 @@ def plot_example_augmentations(train_batch: Batch, max_examples: int = 10) -> PI :, x_start:x_end, y_start:y_end, - ] = image_tensor.cpu() + ] = image_tensor[:3].cpu() # Take only first 3 channels. # Note: Getting the normalization specific to the method is not trivial, # as it depends on the transform. See diff --git a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py index 07bf90521..4afecb174 100644 --- a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py +++ b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py @@ -45,7 +45,7 @@ def __init__( classes: dict[int, str], class_ignore_index: int | None, image_size: tuple[int, int], - image_normalize: dict[str, float], + image_normalize: dict[str, tuple[float, ...]], num_queries: int, num_joint_blocks: int, backbone_weights: PathLike | None = None, @@ -110,6 +110,7 @@ def __init__( # Disable drop path by default. backbone_model_args = { "drop_path_rate": 0.0, + "in_chans": len(self.image_normalize["mean"]), } if backbone_args is not None: backbone_model_args.update(backbone_args) diff --git a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py index 1f3bca86f..f433e7346 100644 --- a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py +++ b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py @@ -69,9 +69,10 @@ class DINOv2EoMTSemanticSegmentationTrainTransformArgs( image_size: tuple[int, int] = (518, 518) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" normalize: NormalizeArgs = Field(default_factory=NormalizeArgs) - random_flip: RandomFlipArgs = Field(default_factory=RandomFlipArgs) - color_jitter: DINOv2EoMTSemanticSegmentationColorJitterArgs = Field( + random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs) + color_jitter: DINOv2EoMTSemanticSegmentationColorJitterArgs | None = Field( default_factory=DINOv2EoMTSemanticSegmentationColorJitterArgs ) scale_jitter: ScaleJitterArgs | None = Field( @@ -90,6 +91,7 @@ class DINOv2EoMTSemanticSegmentationValTransformArgs(SemanticSegmentationTransfo image_size: tuple[int, int] = (518, 518) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" normalize: NormalizeArgs = Field(default_factory=NormalizeArgs) random_flip: RandomFlipArgs | None = None color_jitter: ColorJitterArgs | None = None diff --git a/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/task_model.py b/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/task_model.py index 54ba473c2..230ab5656 100644 --- a/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/task_model.py +++ b/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/task_model.py @@ -41,10 +41,10 @@ def __init__( classes: dict[int, str], class_ignore_index: int | None, backbone_freeze: bool, + image_size: tuple[int, int], + image_normalize: dict[str, tuple[float, ...]], backbone_weights: PathLike | None = None, backbone_args: dict[str, Any] | None = None, - image_size: tuple[int, int], - image_normalize: dict[str, float], ) -> None: """ Args: @@ -58,15 +58,15 @@ def __init__( The class ID assigned to pixels that do not belong to any of the classes in `classes`. If None, the model will not ignore any classes and always assign a class to each pixel. + image_size: + The size to resize images to during inference. Default is (518, 518). + image_normalize: + The normalization parameters for images. Default uses ImageNet stats. backbone_weights: The path to the DINOv2 backbone weights. The weights must be exported using LightlyTrain. backbone_args: Additional arguments to pass to the DINOv2 backbone. - image_size: - The size to resize images to during inference. Default is (518, 518). - image_normalize: - The normalization parameters for images. Default uses ImageNet stats. """ super().__init__(locals(), ignore_args={"backbone_weights"}) parsed_name = self.parse_model_name(model_name=model_name) @@ -97,6 +97,7 @@ def __init__( # Disable drop path by default. args = { "drop_path_rate": 0.0, + "in_chans": len(self.image_normalize["mean"]), } if backbone_args is not None: args.update(backbone_args) diff --git a/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py b/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py index 296f67cf4..2e56525e0 100644 --- a/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py +++ b/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py @@ -69,9 +69,10 @@ class DINOv2LinearSemanticSegmentationTrainTransformArgs( image_size: tuple[int, int] = (518, 518) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" normalize: NormalizeArgs = Field(default_factory=NormalizeArgs) - random_flip: RandomFlipArgs = Field(default_factory=RandomFlipArgs) - color_jitter: DINOv2LinearSemanticSegmentationColorJitterArgs = Field( + random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs) + color_jitter: DINOv2LinearSemanticSegmentationColorJitterArgs | None = Field( default_factory=DINOv2LinearSemanticSegmentationColorJitterArgs ) scale_jitter: ScaleJitterArgs | None = Field( @@ -92,6 +93,7 @@ class DINOv2LinearSemanticSegmentationValTransformArgs( image_size: tuple[int, int] = (518, 518) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" normalize: NormalizeArgs = Field(default_factory=NormalizeArgs) random_flip: RandomFlipArgs | None = None color_jitter: ColorJitterArgs | None = None diff --git a/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/task_model.py b/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/task_model.py index 1dd206e8c..c168d54b6 100644 --- a/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/task_model.py +++ b/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/task_model.py @@ -47,7 +47,7 @@ def __init__( classes: dict[int, str], class_ignore_index: int | None, image_size: tuple[int, int], - image_normalize: dict[str, float], + image_normalize: dict[str, tuple[float, ...]], num_queries: int, num_joint_blocks: int, backbone_url: str | None = None, @@ -115,7 +115,9 @@ def __init__( # NOTE(Guarin, 08/25): We don't set drop_path_rate=0 here because it is already # set by DINOv3. - backbone_model_args: dict[str, Any] = {} + backbone_model_args: dict[str, Any] = { + "in_chans": len(self.image_normalize["mean"]), + } if backbone_url is not None: backbone_model_args["weights"] = backbone_url else: diff --git a/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py b/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py index dbb828f7b..8d0954373 100644 --- a/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py +++ b/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py @@ -70,9 +70,10 @@ class DINOv3EoMTSemanticSegmentationTrainTransformArgs( # TODO(Guarin, 08/25): Check if we should change default to 512. image_size: tuple[int, int] = (518, 518) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" normalize: NormalizeArgs = Field(default_factory=NormalizeArgs) - random_flip: RandomFlipArgs = Field(default_factory=RandomFlipArgs) - color_jitter: DINOv3EoMTSemanticSegmentationColorJitterArgs = Field( + random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs) + color_jitter: DINOv3EoMTSemanticSegmentationColorJitterArgs | None = Field( default_factory=DINOv3EoMTSemanticSegmentationColorJitterArgs ) scale_jitter: ScaleJitterArgs | None = Field( @@ -91,6 +92,7 @@ class DINOv3EoMTSemanticSegmentationValTransformArgs(SemanticSegmentationTransfo image_size: tuple[int, int] = (518, 518) channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" normalize: NormalizeArgs = Field(default_factory=NormalizeArgs) random_flip: RandomFlipArgs | None = None color_jitter: ColorJitterArgs | None = None diff --git a/src/lightly_train/_transforms/semantic_segmentation_transform.py b/src/lightly_train/_transforms/semantic_segmentation_transform.py index a7ca17e6f..65a4e3ae8 100644 --- a/src/lightly_train/_transforms/semantic_segmentation_transform.py +++ b/src/lightly_train/_transforms/semantic_segmentation_transform.py @@ -8,6 +8,9 @@ from __future__ import annotations +import logging +from typing import Literal + import numpy as np from albumentations import ( BasicTransform, @@ -26,6 +29,7 @@ from typing_extensions import NotRequired from lightly_train._configs.validate import no_auto +from lightly_train._transforms.channel_drop import ChannelDrop from lightly_train._transforms.task_transform import ( TaskTransform, TaskTransformArgs, @@ -33,6 +37,7 @@ TaskTransformOutput, ) from lightly_train._transforms.transform import ( + ChannelDropArgs, ColorJitterArgs, NormalizeArgs, RandomCropArgs, @@ -41,6 +46,8 @@ SmallestMaxSizeArgs, ) +logger = logging.getLogger(__name__) + class SemanticSegmentationTransformInput(TaskTransformInput): image: NDArray[np.uint8] @@ -55,6 +62,8 @@ class SemanticSegmentationTransformOutput(TaskTransformOutput): class SemanticSegmentationTransformArgs(TaskTransformArgs): ignore_index: int image_size: tuple[int, int] + channel_drop: ChannelDropArgs | None + num_channels: int | Literal["auto"] normalize: NormalizeArgs random_flip: RandomFlipArgs | None color_jitter: ColorJitterArgs | None @@ -63,15 +72,56 @@ class SemanticSegmentationTransformArgs(TaskTransformArgs): random_crop: RandomCropArgs | None def resolve_auto(self) -> None: + if self.num_channels == "auto": + if self.channel_drop is not None: + self.num_channels = self.channel_drop.num_channels_keep + else: + self.num_channels = len(self.normalize.mean) + height, width = self.image_size for field_name in self.__class__.model_fields: field = getattr(self, field_name) if hasattr(field, "resolve_auto"): field.resolve_auto(height=height, width=width) + def resolve_incompatible(self) -> None: + # Adjust normalization mean and std to match num_channels. + if len(self.normalize.mean) != no_auto(self.num_channels): + logger.debug( + "Adjusting mean of normalize transform to match num_channels. " + f"num_channels is {self.num_channels} but " + f"normalize.mean has length {len(self.normalize.mean)}." + ) + # Repeat the values until they match num_channels. + self.normalize.mean = tuple( + self.normalize.mean[i % len(self.normalize.mean)] + for i in range(no_auto(self.num_channels)) + ) + if len(self.normalize.std) != no_auto(self.num_channels): + logger.debug( + "Adjusting std of normalize transform to match num_channels. " + f"num_channels is {self.num_channels} but " + f"normalize.std has length {len(self.normalize.std)}." + ) + # Repeat the values until they match num_channels. + self.normalize.std = tuple( + self.normalize.std[i % len(self.normalize.std)] + for i in range(no_auto(self.num_channels)) + ) + + # Disable color jitter if necessary. + if self.color_jitter is not None and no_auto(self.num_channels) != 3: + logger.debug( + "Disabling color jitter transform as it only supports 3-channel " + f"images but num_channels is {self.num_channels}." + ) + self.color_jitter = None + class SemanticSegmentationTransform(TaskTransform): - transform_args_cls: type[SemanticSegmentationTransformArgs] + transform_args_cls: type[SemanticSegmentationTransformArgs] = ( + SemanticSegmentationTransformArgs + ) def __init__(self, transform_args: SemanticSegmentationTransformArgs) -> None: super().__init__(transform_args) @@ -79,6 +129,14 @@ def __init__(self, transform_args: SemanticSegmentationTransformArgs) -> None: # Initialize the list of transforms to apply. transform: list[BasicTransform] = [] + if transform_args.channel_drop is not None: + transform += [ + ChannelDrop( + num_channels_keep=transform_args.channel_drop.num_channels_keep, + weight_drop=transform_args.channel_drop.weight_drop, + ) + ] + if transform_args.scale_jitter is not None: # This follows recommendation on how to replace torchvision ScaleJitter with # albumentations: https://albumentations.ai/docs/torchvision-kornia2albumentations/ diff --git a/src/lightly_train/_transforms/task_transform.py b/src/lightly_train/_transforms/task_transform.py index 7e89d3e95..43cf49310 100644 --- a/src/lightly_train/_transforms/task_transform.py +++ b/src/lightly_train/_transforms/task_transform.py @@ -24,6 +24,11 @@ class TaskTransformOutput(TypedDict): class TaskTransformArgs(PydanticConfig): def resolve_auto(self) -> None: + """Resolve any arguments set to "auto".""" + pass + + def resolve_incompatible(self) -> None: + """Resolve any incompatible arguments.""" pass model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/lightly_train/_transforms/transform.py b/src/lightly_train/_transforms/transform.py index ad316680b..044ea8e97 100644 --- a/src/lightly_train/_transforms/transform.py +++ b/src/lightly_train/_transforms/transform.py @@ -7,6 +7,7 @@ # from __future__ import annotations +import logging from collections.abc import Sequence from typing import ( Literal, @@ -19,8 +20,11 @@ from pydantic import Field from lightly_train._configs.config import PydanticConfig +from lightly_train._configs.validate import no_auto from lightly_train.types import TransformInput, TransformOutput +logger = logging.getLogger(__name__) + class ChannelDropArgs(PydanticConfig): num_channels_keep: int @@ -92,7 +96,7 @@ class SolarizeArgs(PydanticConfig): class NormalizeArgs(PydanticConfig): # Strict is set to False because OmegaConf does not support parsing tuples from the # CLI. Setting strict to False allows Pydantic to convert lists to tuples. - mean: tuple[float, float, float] = Field( + mean: tuple[float, ...] = Field( default=( IMAGENET_NORMALIZE["mean"][0], IMAGENET_NORMALIZE["mean"][1], @@ -100,7 +104,7 @@ class NormalizeArgs(PydanticConfig): ), strict=False, ) - std: tuple[float, float, float] = Field( + std: tuple[float, ...] = Field( default=( IMAGENET_NORMALIZE["std"][0], IMAGENET_NORMALIZE["std"][1], @@ -160,6 +164,7 @@ class MethodTransformArgs(PydanticConfig): # CLI. Setting strict to False allows Pydantic to convert lists to tuples. image_size: tuple[int, int] channel_drop: ChannelDropArgs | None + num_channels: int | Literal["auto"] random_resize: RandomResizeArgs | None random_flip: RandomFlipArgs | None random_rotation: RandomRotationArgs | None @@ -169,6 +174,58 @@ class MethodTransformArgs(PydanticConfig): gaussian_blur: GaussianBlurArgs | None solarize: SolarizeArgs | None + def resolve_auto(self) -> None: + if self.num_channels == "auto": + if self.channel_drop is not None: + self.num_channels = self.channel_drop.num_channels_keep + else: + self.num_channels = len(self.normalize.mean) + + def resolve_incompatible(self) -> None: + # Adjust normalization mean and std to match num_channels. + if len(self.normalize.mean) != no_auto(self.num_channels): + logger.debug( + "Adjusting mean of normalize transform to match num_channels. " + f"num_channels is {self.num_channels} but " + f"normalize.mean has length {len(self.normalize.mean)}." + ) + # Repeat the values until they match num_channels. + self.normalize.mean = tuple( + self.normalize.mean[i % len(self.normalize.mean)] + for i in range(no_auto(self.num_channels)) + ) + if len(self.normalize.std) != no_auto(self.num_channels): + logger.debug( + "Adjusting std of normalize transform to match num_channels. " + f"num_channels is {self.num_channels} but " + f"normalize.std has length {len(self.normalize.std)}." + ) + # Repeat the values until they match num_channels. + self.normalize.std = tuple( + self.normalize.std[i % len(self.normalize.std)] + for i in range(no_auto(self.num_channels)) + ) + + # Disable transforms if necessary. + if self.color_jitter is not None and no_auto(self.num_channels) != 3: + logger.debug( + "Disabling color jitter transform as it only supports 3-channel " + f"images but num_channels is {self.num_channels}." + ) + self.color_jitter = None + if self.random_gray_scale is not None and no_auto(self.num_channels) != 3: + logger.debug( + "Disabling random gray scale transform as it only supports 3-channel " + f"images but num_channels is {self.num_channels}." + ) + self.random_gray_scale = None + if self.solarize is not None and no_auto(self.num_channels) != 3: + logger.debug( + "Disabling solarize transform as it only supports 3-channel " + f"images but num_channels is {self.num_channels}." + ) + self.solarize = None + _T = TypeVar("_T", covariant=True) diff --git a/tests/_commands/test_common_helpers.py b/tests/_commands/test_common_helpers.py index c83e9683c..c87adcf41 100644 --- a/tests/_commands/test_common_helpers.py +++ b/tests/_commands/test_common_helpers.py @@ -858,6 +858,7 @@ def test_get_dataset__path(tmp_path: Path) -> None: _ = common_helpers.get_dataset( data=tmp_path, transform=ToTensorV2(), + num_channels=3, mmap_filepath=mmap_filepath, out_dir=tmp_path, ) @@ -868,6 +869,7 @@ def test_get_dataset__path__nonexisting(tmp_path: Path) -> None: common_helpers.get_dataset( data=tmp_path / "nonexisting", transform=ToTensorV2(), + num_channels=3, mmap_filepath=None, out_dir=tmp_path, ) @@ -880,6 +882,7 @@ def test_get_dataset__path__nondir(tmp_path: Path) -> None: common_helpers.get_dataset( data=file, transform=ToTensorV2(), + num_channels=3, mmap_filepath=None, out_dir=tmp_path, ) @@ -890,6 +893,7 @@ def test_get_dataset__path__empty(tmp_path: Path) -> None: common_helpers.get_dataset( data=tmp_path, transform=ToTensorV2(), + num_channels=3, mmap_filepath=None, out_dir=tmp_path, ) @@ -913,6 +917,7 @@ def test_get_dataset__dirs_and_files(tmp_path: Path) -> None: img_dir, ], transform=ToTensorV2(), + num_channels=3, mmap_filepath=mmap_filepath, out_dir=tmp_path, ) @@ -923,6 +928,7 @@ def test_get_dataset__dataset() -> None: dataset_1 = common_helpers.get_dataset( data=dataset, transform=ToTensorV2(), + num_channels=3, mmap_filepath=None, out_dir=Path("/tmp"), ) diff --git a/tests/_commands/test_train.py b/tests/_commands/test_train.py index b805e6f41..56890e4b4 100644 --- a/tests/_commands/test_train.py +++ b/tests/_commands/test_train.py @@ -14,6 +14,7 @@ import pytest import torch +from lightning_utilities.core.imports import RequirementCache from omegaconf import OmegaConf from pytest import LogCaptureFixture from pytest_mock import MockerFixture @@ -483,3 +484,34 @@ def test_train__checkpoint(mocker: MockerFixture, tmp_path: Path) -> None: # Skip the last layer as it is not pretrained. continue assert torch.equal(second_state_dict[key], exported_state_dict[key]) + + +@pytest.mark.parametrize( + "model, method, method_args", + [ + ("dinov2/_vittest14", "dinov2", {}), + ("timm/resnet18", "distillation", {"teacher": "dinov2/_vittest14"}), + ], +) +def test_train__multichannel( + tmp_path: Path, model: str, method: str, method_args: dict[str, Any] +) -> None: + if model.startswith("timm") and not RequirementCache("timm"): + pytest.skip("timm is not installed") + + out = tmp_path / "out" + data = tmp_path / "data" + helpers.create_images(image_dir=data, files=10, num_channels=4, mode="RGBA") + + train.train( + out=out, + data=data, + model=model, + method=method, + method_args=method_args, + batch_size=4, + num_workers=0, + epochs=1, + devices=1, + embed_dim=64, + ) diff --git a/tests/_commands/test_train_helpers.py b/tests/_commands/test_train_helpers.py index e9db0749c..ed3110d7d 100644 --- a/tests/_commands/test_train_helpers.py +++ b/tests/_commands/test_train_helpers.py @@ -74,18 +74,22 @@ def __len__(self) -> int: def test_get_transform__method() -> None: + transform_args = SimCLRTransformArgs() + transform_args.resolve_auto() assert isinstance( train_helpers.get_transform( - method="simclr", transform_args_resolved=SimCLRTransformArgs() + method="simclr", transform_args_resolved=transform_args ), SimCLRTransform, ) def test_get_transform__method_and_transform_dict() -> None: + transform_args = SimCLRTransformArgs(random_gray_scale=0.42) + transform_args.resolve_auto() transform = train_helpers.get_transform( method="simclr", - transform_args_resolved=SimCLRTransformArgs(random_gray_scale=0.42), + transform_args_resolved=transform_args, ) assert isinstance(transform, SimCLRTransform) assert transform.transform_args.random_gray_scale == 0.42 @@ -146,7 +150,9 @@ def test_get_embedding_model( if model_name.startswith("timm/"): pytest.importorskip("timm") x = torch.rand(1, 3, 224, 224) - model = package_helpers.get_wrapped_model(model_name, model_args=model_args) + model = package_helpers.get_wrapped_model( + model_name, model_args=model_args, num_input_channels=3 + ) embedding_model = train_helpers.get_embedding_model(model, embed_dim=embed_dim) embedding = embedding_model.forward(x) assert embedding.shape == (1, embedding_model.embed_dim, 1, 1) @@ -355,6 +361,7 @@ def test_get_method() -> None: optimizer_args=AdamWArgs(), embedding_model=embedding_model, global_batch_size=1, + num_input_channels=3, ) assert isinstance(method, SimCLR) assert method.method_args.temperature == 0.2 @@ -406,9 +413,9 @@ def test_get_epochs( "transform_dict, expected_result", [ # Test case for default empty dictionary - ({}, SimCLRTransformArgs()), + ({}, SimCLRTransformArgs(num_channels=3)), # Test case for None input - (None, SimCLRTransformArgs()), + (None, SimCLRTransformArgs(num_channels=3)), # Test case for user config ( { @@ -417,6 +424,7 @@ def test_get_epochs( "color_jitter": {"brightness": 0.1}, }, SimCLRTransformArgs( + num_channels=3, normalize=NormalizeArgs(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), random_rotation=RandomRotationArgs(prob=0.5, degrees=30), color_jitter=SimCLRColorJitterArgs(brightness=0.1), @@ -428,7 +436,8 @@ def test_get_epochs( normalize=NormalizeArgs(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ), SimCLRTransformArgs( - normalize=NormalizeArgs(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + num_channels=3, + normalize=NormalizeArgs(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), ), ), ], diff --git a/tests/_commands/test_train_task.py b/tests/_commands/test_train_task.py index 52eec0225..e76d28417 100644 --- a/tests/_commands/test_train_task.py +++ b/tests/_commands/test_train_task.py @@ -5,7 +5,10 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # +from __future__ import annotations + from pathlib import Path +from typing import Any import pytest from lightning_utilities.core.imports import RequirementCache @@ -36,15 +39,27 @@ "OR on self-hosted CI with GPU (insufficient shared memory causes worker bus error)" ), ) -def test_train_semantic_segmentation(tmp_path: Path) -> None: +@pytest.mark.parametrize( + "model_name, model_args", + [ + # Reduce number of joint blocks _vittest14. + ("dinov2/_vittest14-eomt", {"num_joint_blocks": 1}), + ("dinov2/_vittest14-linear", {}), + ], +) +@pytest.mark.parametrize("num_channels", [3, 4]) +def test_train_semantic_segmentation( + tmp_path: Path, model_name: str, model_args: dict[str, Any], num_channels: int +) -> None: out = tmp_path / "out" train_images = tmp_path / "train_images" train_masks = tmp_path / "train_masks" val_images = tmp_path / "val_images" val_masks = tmp_path / "val_masks" - helpers.create_images(train_images) + mode = "RGB" if num_channels == 3 else "RGBA" + helpers.create_images(train_images, num_channels=num_channels, mode=mode) helpers.create_masks(train_masks) - helpers.create_images(val_images) + helpers.create_images(val_images, num_channels=num_channels, mode=mode) helpers.create_masks(val_masks) lightly_train.train_semantic_segmentation( @@ -63,10 +78,8 @@ def test_train_semantic_segmentation(tmp_path: Path) -> None: 1: "car", }, }, - model="dinov2/_vittest14-eomt", - model_args={ - "num_joint_blocks": 1, # Reduce joint blocks for _vittest14 - }, + model=model_name, + model_args=model_args, # The operator 'aten::upsample_bicubic2d.out' raises a NotImplementedError # on macOS with MPS backend. accelerator="auto" if not sys.platform.startswith("darwin") else "cpu", @@ -74,6 +87,9 @@ def test_train_semantic_segmentation(tmp_path: Path) -> None: batch_size=2, num_workers=2, steps=2, + transform_args={ + "num_channels": num_channels, + }, ) assert out.exists() assert out.is_dir() @@ -83,7 +99,7 @@ def test_train_semantic_segmentation(tmp_path: Path) -> None: checkpoint=out / "checkpoints" / "last.ckpt" ) # Check forward pass - dummy_input = torch.randn(1, 3, 224, 224) + dummy_input = torch.randn(1, num_channels, 224, 224) prediction = model.predict(dummy_input[0]) assert prediction.shape == (224, 224) assert prediction.min() >= 0 diff --git a/tests/_data/test_image_dataset.py b/tests/_data/test_image_dataset.py index 24604f86a..f34f51d79 100644 --- a/tests/_data/test_image_dataset.py +++ b/tests/_data/test_image_dataset.py @@ -112,12 +112,22 @@ def nested_image_dir(tmp_path: Path) -> Path: class TestImageDataset: - def test___getitem__(self, flat_image_dir: Path) -> None: - filenames = [ImageFilename("image1.jpg"), ImageFilename("image2.jpg")] + @pytest.mark.parametrize("num_channels", [3, 4]) + def test___getitem__(self, tmp_path: Path, num_channels: int) -> None: + helpers.create_images( + tmp_path, + files=["image1.png", "image2.png"], + height=32, + width=32, + num_channels=num_channels, + mode="RGB" if num_channels == 3 else "RGBA", + ) + filenames = [ImageFilename("image1.png"), ImageFilename("image2.png")] dataset = ImageDataset( - image_dir=flat_image_dir, + image_dir=tmp_path, image_filenames=filenames, transform=DummyMethodTransform(), + num_channels=num_channels, ) assert len(dataset) == 2 for i in range(2): @@ -127,7 +137,7 @@ def test___getitem__(self, flat_image_dir: Path) -> None: assert isinstance(item["views"], list) assert len(item["views"]) == 1 assert isinstance(item["views"][0], Tensor) - assert item["views"][0].shape == (3, 32, 32) + assert item["views"][0].shape == (num_channels, 32, 32) assert "mask" not in item @pytest.mark.parametrize( @@ -156,6 +166,7 @@ def test___getitem____mode(self, tmp_path: Path, mode: str, extension: str) -> N image_dir=image_dir, image_filenames=filenames, transform=DummyMethodTransform(), + num_channels=3, ) image = dataset[0]["views"][0] assert isinstance(image, Tensor) @@ -177,6 +188,7 @@ def test___getitem____truncated(self, tmp_path: Path) -> None: image_dir=image_dir, image_filenames=filenames, transform=DummyMethodTransform(), + num_channels=3, ) image = dataset[0]["views"][0] assert isinstance(image, Tensor) @@ -196,6 +208,7 @@ def test___getitem____masks(self, tmp_path: Path) -> None: image_filenames=img_filenames, mask_dir=mask_dir, transform=DummyMethodTransform(), + num_channels=3, ) item: DatasetItem = dataset[0] print(f"{item=}") @@ -217,6 +230,7 @@ def test_dataloader(self, flat_image_dir: Path) -> None: image_dir=flat_image_dir, image_filenames=filenames, transform=DummyMethodTransform(), + num_channels=3, ) assert len(dataset) == 2 dataloader = DataLoader( diff --git a/tests/_data/test_mask_semantic_segmentation_dataset.py b/tests/_data/test_mask_semantic_segmentation_dataset.py index eb848d8c8..d04fc5b0a 100644 --- a/tests/_data/test_mask_semantic_segmentation_dataset.py +++ b/tests/_data/test_mask_semantic_segmentation_dataset.py @@ -10,7 +10,6 @@ from pathlib import Path from typing import Any -import albumentations as A import pytest import torch from torch import Tensor @@ -22,32 +21,34 @@ MaskSemanticSegmentationDatasetArgs, SplitArgs, ) -from lightly_train._transforms.task_transform import ( - TaskTransform, - TaskTransformArgs, - TaskTransformInput, - TaskTransformOutput, +from lightly_train._transforms.semantic_segmentation_transform import ( + SemanticSegmentationTransform, + SemanticSegmentationTransformArgs, +) +from lightly_train._transforms.transform import ( + NormalizeArgs, + SmallestMaxSizeArgs, ) from .. import helpers -class DummyTransform(TaskTransform): - transform_args_cls = TaskTransformArgs - - def __init__(self, transform_args: TaskTransformArgs): - super().__init__(transform_args=transform_args) - self.transform = A.Compose( - [ - A.Resize(32, 32), - A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - A.pytorch.transforms.ToTensorV2(), - ] - ) - - def __call__(self, input: TaskTransformInput) -> TaskTransformOutput: - output: TaskTransformOutput = self.transform(**input) - return output +def _dummy_transform(num_channels: int = 3) -> SemanticSegmentationTransform: + args = SemanticSegmentationTransformArgs( + ignore_index=-100, + image_size=(32, 32), + channel_drop=None, + num_channels=num_channels, + normalize=NormalizeArgs(), + random_flip=None, + color_jitter=None, + scale_jitter=None, + smallest_max_size=SmallestMaxSizeArgs(max_size=32, prob=1.0), + random_crop=None, + ) + args.resolve_auto() + args.resolve_incompatible() + return SemanticSegmentationTransform(args) class TestMaskSemanticSegmentationDataArgs: @@ -216,25 +217,32 @@ def test_included_classes(self, tmp_path: Path) -> None: class TestMaskSemanticSegmentationDataset: @pytest.mark.parametrize( - "num_classes, expected_mask_dtype, ignore_index", + "num_classes, num_channels, expected_mask_dtype, ignore_index", [ - (5, torch.long, -100), - (500, torch.long, -100), + (5, 3, torch.long, -100), + (5, 4, torch.long, -100), + (500, 3, torch.long, -100), ], ) def test__getitem__( self, num_classes: int, + num_channels: int, expected_mask_dtype: torch.dtype, tmp_path: Path, ignore_index: int, ) -> None: image_dir = tmp_path / "images" mask_dir = tmp_path / "masks" - image_filenames = ["image0.jpg", "image1.jpg"] + image_filenames = ["image0.png", "image1.png"] mask_filenames = ["image0.png", "image1.png"] - helpers.create_images(image_dir, files=image_filenames) + helpers.create_images( + image_dir, + files=image_filenames, + num_channels=num_channels, + mode="RGB" if num_channels == 3 else "RGBA", + ) helpers.create_masks(mask_dir, files=mask_filenames, num_classes=num_classes) dataset_args = MaskSemanticSegmentationDatasetArgs( @@ -245,7 +253,7 @@ def test__getitem__( }, ignore_index=ignore_index, ) - transform = DummyTransform(transform_args=TaskTransformArgs()) + transform = _dummy_transform(num_channels=num_channels) dataset = MaskSemanticSegmentationDataset( dataset_args=dataset_args, image_info=list(dataset_args.list_image_info()), @@ -255,7 +263,7 @@ def test__getitem__( assert len(dataset) == 2 for item in dataset: # type: ignore[attr-defined] assert isinstance(item["image"], Tensor) - assert item["image"].shape == (3, 32, 32) + assert item["image"].shape == (num_channels, 32, 32) assert item["image"].dtype == torch.float32 assert isinstance(item["mask"], Tensor) assert item["mask"].shape == (32, 32) @@ -273,8 +281,8 @@ def test__getitem__( ignored_pixels = mask == ignore_index assert (ignored_pixels.sum() + valid_pixels.sum()) == mask.numel() assert sorted(item["image_path"] for item in dataset) == [ # type: ignore[attr-defined] - str(image_dir / "image0.jpg"), - str(image_dir / "image1.jpg"), + str(image_dir / "image0.png"), + str(image_dir / "image1.png"), ] def test_get_class_mapping(self, tmp_path: Path) -> None: @@ -298,7 +306,7 @@ def test_get_class_mapping(self, tmp_path: Path) -> None: classes=classes, ignore_index=-100, ) - transform = DummyTransform(transform_args=TaskTransformArgs()) + transform = _dummy_transform() dataset = MaskSemanticSegmentationDataset( dataset_args=dataset_args, image_info=list(dataset_args.list_image_info()), @@ -331,7 +339,7 @@ def test_get_class_mapping__ignore_classes(self, tmp_path: Path) -> None: ignore_classes=ignore_classes, ignore_index=-100, ) - transform = DummyTransform(transform_args=TaskTransformArgs()) + transform = _dummy_transform() dataset = MaskSemanticSegmentationDataset( dataset_args=dataset_args, image_info=list(dataset_args.list_image_info()), diff --git a/tests/_methods/detcon/test_detcon.py b/tests/_methods/detcon/test_detcon.py index 2f2348b12..2b1056686 100644 --- a/tests/_methods/detcon/test_detcon.py +++ b/tests/_methods/detcon/test_detcon.py @@ -136,6 +136,7 @@ def test_training_step_impl(self) -> None: optimizer_args=DetConSSGDArgs(), embedding_model=emb_model, global_batch_size=b, + num_input_channels=3, ) out = detcons.training_step_impl(batch, 0) @@ -164,6 +165,7 @@ def test_training_step_impl(self) -> None: optimizer_args=DetConBSGDArgs(), embedding_model=emb_model, global_batch_size=b, + num_input_channels=3, ) out = detconb.training_step_impl(batch, 0) assert out.loss.shape == Size([]) diff --git a/tests/_methods/dinov2/test_dinov2.py b/tests/_methods/dinov2/test_dinov2.py index b7a227963..c11d38d96 100644 --- a/tests/_methods/dinov2/test_dinov2.py +++ b/tests/_methods/dinov2/test_dinov2.py @@ -48,6 +48,7 @@ def setup_dinov2_helper( optimizer_args=optimizer_args, embedding_model=emb_model, global_batch_size=batch_size, + num_input_channels=3, ) trainer_mock = mocker.Mock() diff --git a/tests/_methods/distillation/test_distillation.py b/tests/_methods/distillation/test_distillation.py index 10f43c1f0..b09cbe361 100644 --- a/tests/_methods/distillation/test_distillation.py +++ b/tests/_methods/distillation/test_distillation.py @@ -202,6 +202,7 @@ def test_queue_update(self, mocker: MockerFixture) -> None: optimizer_args=DistillationLARSArgs(), embedding_model=EmbeddingModel(wrapped_model=DummyCustomModel()), global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher_model.assert_called_once() @@ -253,6 +254,7 @@ def test_teacher_queue_never_exceeds_capacity(self, mocker: MockerFixture) -> No optimizer_args=DistillationLARSArgs(), embedding_model=EmbeddingModel(wrapped_model=DummyCustomModel()), global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher_model.assert_called_once() @@ -318,6 +320,7 @@ def test_load_state_dict_from_pretrained_teacher( optimizer_args=DistillationLARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) def test_load_state_dict_ignores_missing_teacher_keys( @@ -353,6 +356,7 @@ def test_load_state_dict_ignores_missing_teacher_keys( optimizer_args=DistillationLARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher.assert_called_once() @@ -403,6 +407,7 @@ def test_load_state_dict_raises_on_non_teacher_missing_key( optimizer_args=DistillationLARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher.assert_called_once() @@ -454,6 +459,7 @@ def test_teacher_not_saved_in_checkpoint(self, mocker: MockerFixture) -> None: optimizer_args=DistillationLARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher.assert_called_once() @@ -529,6 +535,7 @@ def test_distillation_configure_optimizers_lr_scaling( optimizer_args=DistillationLARSArgs(lr=base_lr), embedding_model=student_model, global_batch_size=global_batch_size, + num_input_channels=3, ) # Mock trainer attributes needed by configure_optimizers. diff --git a/tests/_methods/distillationv2/test_distillationv2.py b/tests/_methods/distillationv2/test_distillationv2.py index cfc173c6f..32017da7d 100644 --- a/tests/_methods/distillationv2/test_distillationv2.py +++ b/tests/_methods/distillationv2/test_distillationv2.py @@ -135,6 +135,7 @@ def test_forward_student_output_shape(self, mocker: MockerFixture) -> None: optimizer_args=DistillationV2LARSArgs(), embedding_model=mock_student_model, global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher.assert_called_once() @@ -184,6 +185,7 @@ def test_load_state_dict_from_pretrained_teacher( optimizer_args=DistillationV2LARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) def test_load_state_dict_ignores_missing_teacher_keys( @@ -218,6 +220,7 @@ def test_load_state_dict_ignores_missing_teacher_keys( optimizer_args=DistillationV2LARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher.assert_called_once() @@ -267,6 +270,7 @@ def test_load_state_dict_raises_on_non_teacher_missing_key( optimizer_args=DistillationV2LARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher.assert_called_once() @@ -317,6 +321,7 @@ def test_teacher_not_saved_in_checkpoint(self, mocker: MockerFixture) -> None: optimizer_args=DistillationV2LARSArgs(), embedding_model=student_model, global_batch_size=batch_size, + num_input_channels=3, ) mock_get_teacher.assert_called_once() @@ -391,6 +396,7 @@ def test_distillation_configure_optimizers_lr_scaling( optimizer_args=DistillationV2LARSArgs(lr=base_lr), embedding_model=student_model, global_batch_size=global_batch_size, + num_input_channels=3, ) # Mock trainer attributes needed by configure_optimizers. diff --git a/tests/_models/test_package_helpers.py b/tests/_models/test_package_helpers.py index 276f50a1e..ab3940271 100644 --- a/tests/_models/test_package_helpers.py +++ b/tests/_models/test_package_helpers.py @@ -49,20 +49,27 @@ def test_get_model__rfdetr() -> None: pytest.importorskip("rfdetr") from rfdetr.detr import RFDETR - model = package_helpers.get_wrapped_model("rfdetr/rf-detr-base") + model = package_helpers.get_wrapped_model( + "rfdetr/rf-detr-base", num_input_channels=3 + ) assert isinstance(model.get_model(), RFDETR) def test_get_model__torchvision() -> None: - model = package_helpers.get_wrapped_model("torchvision/resnet18") + model = package_helpers.get_wrapped_model( + "torchvision/resnet18", num_input_channels=3 + ) assert isinstance(model.get_model(), ResNet) -def test_get_model__timm() -> None: +@pytest.mark.parametrize("num_input_channels", [3, 4]) +def test_get_model__timm(num_input_channels: int) -> None: pytest.importorskip("timm") from timm.models.resnet import ResNet - model = package_helpers.get_wrapped_model("timm/resnet18") + model = package_helpers.get_wrapped_model( + "timm/resnet18", num_input_channels=num_input_channels + ) assert isinstance(model.get_model(), ResNet) @@ -72,7 +79,9 @@ def test_get_model__super_gradients() -> None: YoloNAS_S, ) - model = package_helpers.get_wrapped_model("super_gradients/yolo_nas_s") + model = package_helpers.get_wrapped_model( + "super_gradients/yolo_nas_s", num_input_channels=3 + ) assert isinstance(model.get_model(), YoloNAS_S) @@ -80,13 +89,17 @@ def test_get_model__ultralytics() -> None: pytest.importorskip("ultralytics") from ultralytics import YOLO - model = package_helpers.get_wrapped_model("ultralytics/yolov8s.yaml") + model = package_helpers.get_wrapped_model( + "ultralytics/yolov8s.yaml", num_input_channels=3 + ) assert isinstance(model.get_model(), YOLO) def test_get_model_wrapper__timm() -> None: pytest.importorskip("timm") - wrapped_model = package_helpers.get_wrapped_model("timm/resnet18") + wrapped_model = package_helpers.get_wrapped_model( + "timm/resnet18", num_input_channels=3 + ) model = wrapped_model.get_model() x = torch.rand(1, 3, 64, 64) diff --git a/tests/helpers.py b/tests/helpers.py index 545ca4d1d..ef003af4a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -59,12 +59,14 @@ def __init__( optimizer_args: OptimizerArgs, embedding_model: EmbeddingModel, global_batch_size: int, + num_input_channels: int = 3, ): super().__init__( method_args=method_args, optimizer_args=optimizer_args, embedding_model=embedding_model, global_batch_size=global_batch_size, + num_input_channels=num_input_channels, ) self.embedding_model = embedding_model self.method_args = method_args @@ -114,13 +116,16 @@ def get_method(wrapped_model: ModelWrapper) -> Method: optimizer_args=AdamWArgs(), embedding_model=EmbeddingModel(wrapped_model=wrapped_model), global_batch_size=2, + num_input_channels=3, ) def get_method_dinov2() -> DINOv2: optim_args = DINOv2AdamWViTArgs() dinov2_args = DINOv2Args() - wrapped_model = package_helpers.get_wrapped_model(model="dinov2/_vittest14") + wrapped_model = package_helpers.get_wrapped_model( + model="dinov2/_vittest14", num_input_channels=3 + ) dinov2_args.resolve_auto( scaling_info=ScalingInfo(dataset_size=1000, epochs=100), optimizer_args=optim_args, @@ -131,6 +136,7 @@ def get_method_dinov2() -> DINOv2: optimizer_args=optim_args, embedding_model=EmbeddingModel(wrapped_model=wrapped_model), global_batch_size=2, + num_input_channels=3, ) return dinov2