Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightly_train/_commands/common_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def get_dataset_mmap_filenames(
def get_dataset(
data: PathLike | Sequence[PathLike] | Dataset[DatasetItem],
transform: Transform,
num_channels: int,
mmap_filepath: Path | None,
out_dir: Path,
) -> Dataset[DatasetItem]:
Expand Down Expand Up @@ -530,6 +531,7 @@ def get_dataset(
mmap_filepath=mmap_filepath,
),
transform=transform,
num_channels=num_channels,
mask_dir=Path(mask_dir) if mask_dir is not None else None,
)

Expand All @@ -556,6 +558,7 @@ def get_dataset(
mmap_filepath=mmap_filepath,
),
transform=transform,
num_channels=num_channels,
)
else:
raise ValueError(
Expand Down
4 changes: 3 additions & 1 deletion src/lightly_train/_commands/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ def embed_from_config(config: EmbedConfig) -> None:
checkpoint_path = common_helpers.get_checkpoint_path(checkpoint=config.checkpoint)
writer = writer_helpers.get_writer(format=format, filepath=out_path)
checkpoint_instance = _get_checkpoint(checkpoint=checkpoint_path)
normalize_args = checkpoint_instance.lightly_train.normalize_args
transform = _get_transform(
image_size=config.image_size,
normalize_args=checkpoint_instance.lightly_train.normalize_args,
normalize_args=normalize_args,
)
num_workers = common_helpers.get_num_workers(
num_workers=config.num_workers, num_devices_per_node=1
Expand All @@ -134,6 +135,7 @@ def embed_from_config(config: EmbedConfig) -> None:
dataset = common_helpers.get_dataset(
data=config.data,
transform=transform,
num_channels=len(normalize_args.mean),
mmap_filepath=mmap_filepath,
out_dir=out_path,
)
Expand Down
7 changes: 6 additions & 1 deletion src/lightly_train/_commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lightly_train._commands.common_helpers import ModelFormat
from lightly_train._configs import omegaconf_utils, validate
from lightly_train._configs.config import PydanticConfig
from lightly_train._configs.validate import no_auto
from lightly_train._loggers import logger_helpers
from lightly_train._loggers.logger_args import LoggerArgs
from lightly_train._methods import method_helpers
Expand Down Expand Up @@ -294,6 +295,7 @@ def train_from_config(config: TrainConfig) -> None:
dataset = common_helpers.get_dataset(
data=config.data,
transform=transform_instance,
num_channels=no_auto(transform_instance.transform_args.num_channels),
mmap_filepath=mmap_filepath,
out_dir=out_dir,
)
Expand All @@ -309,7 +311,9 @@ def train_from_config(config: TrainConfig) -> None:
epochs=config.epochs,
)
wrapped_model = package_helpers.get_wrapped_model(
model=config.model, model_args=config.model_args
model=config.model,
model_args=config.model_args,
num_input_channels=no_auto(transform_instance.transform_args.num_channels),
)
embedding_model = train_helpers.get_embedding_model(
wrapped_model=wrapped_model, embed_dim=config.embed_dim
Expand Down Expand Up @@ -397,6 +401,7 @@ def train_from_config(config: TrainConfig) -> None:
optimizer_args=config.optim_args,
embedding_model=embedding_model,
global_batch_size=config.batch_size,
num_input_channels=no_auto(transform_instance.transform_args.num_channels),
)
train_helpers.load_checkpoint(
checkpoint=config.checkpoint,
Expand Down
12 changes: 10 additions & 2 deletions src/lightly_train/_commands/train_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ def get_transform_args(
if transform_args is None:
# We need a type-ignore here because MethodTransformArgs itself might not have
# defaults for all of its fields, while its subclasses do.
return transform_args_cls() # type: ignore[call-arg]
transform_args = transform_args_cls() # type: ignore[call-arg]
else:
transform_args = validate.pydantic_model_validate(
transform_args_cls, transform_args
)

return validate.pydantic_model_validate(transform_args_cls, transform_args)
transform_args.resolve_auto()
transform_args.resolve_incompatible()
return transform_args


def get_transform(
Expand Down Expand Up @@ -326,13 +332,15 @@ def get_method(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
) -> Method:
logger.debug(f"Getting method for '{method_cls.__name__}'")
return method_cls(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)


Expand Down
14 changes: 13 additions & 1 deletion src/lightly_train/_commands/train_task_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@
TrainModelArgs,
)
from lightly_train._train_task_state import TrainTaskState
from lightly_train._transforms.semantic_segmentation_transform import (
SemanticSegmentationTransform,
)
from lightly_train._transforms.task_transform import (
TaskTransform,
TaskTransformArgs,
Expand Down Expand Up @@ -226,17 +229,24 @@ def get_transform_args(
train_transform_args_cls, transform_args
)
train_transform_args.resolve_auto()
train_transform_args.resolve_incompatible()

# Take defaults from train transform.
val_args_dict = train_transform_args.model_dump(
include={"image_size": True, "normalize": True, "ignore_index": True}
include={
"image_size": True,
"normalize": True,
"ignore_index": True,
"num_channels": True,
}
)
# Overwrite with user-provided val args.
val_args_dict.update(val_args)
val_transform_args = validate.pydantic_model_validate(
val_transform_args_cls, val_args_dict
)
val_transform_args.resolve_auto()
val_transform_args.resolve_incompatible()

logger.debug(
f"Resolved train transform args {pretty_format_args(train_transform_args.model_dump())}"
Expand Down Expand Up @@ -409,6 +419,8 @@ def get_dataset(
) -> MaskSemanticSegmentationDataset:
filenames = list(dataset_args.list_image_filenames())
dataset_cls = dataset_args.get_dataset_cls()
# TODO(Guarin, 08/25): Relax this when we add object detection.
assert isinstance(transform, SemanticSegmentationTransform)
return dataset_cls(
dataset_args=dataset_args,
image_filenames=get_dataset_mmap_filenames(
Expand Down
21 changes: 16 additions & 5 deletions src/lightly_train/_data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,31 @@ def __init__(
image_dir: Path | None,
image_filenames: Sequence[ImageFilename],
transform: Transform,
num_channels: int,
mask_dir: Path | None = None,
):
self.image_dir = image_dir
self.image_filenames = image_filenames
self.mask_dir = mask_dir
self.transform = transform
self.num_channels = num_channels

try:
self.image_mode = ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value)
except ValueError:
image_mode = (
None
if Env.LIGHTLY_TRAIN_IMAGE_MODE.value is None
else ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value)
)
if image_mode is None:
image_mode = (
ImageMode.RGB if self.num_channels == 3 else ImageMode.UNCHANGED
)

if image_mode not in (ImageMode.RGB, ImageMode.UNCHANGED):
raise ValueError(
f'Invalid image mode: {Env.LIGHTLY_TRAIN_IMAGE_MODE.name}="{Env.LIGHTLY_TRAIN_IMAGE_MODE.value}". '
"Supported modes are 'RGB' and 'UNCHANGED'."
f"Invalid image mode: '{image_mode}'. "
f"Supported modes are '{[ImageMode.RGB.value, ImageMode.UNCHANGED.value]}'."
)
self.image_mode = image_mode

def __getitem__(self, idx: int) -> DatasetItem:
filename = self.image_filenames[idx]
Expand Down
38 changes: 24 additions & 14 deletions src/lightly_train/_data/mask_semantic_segmentation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from lightly_train._data.file_helpers import ImageMode
from lightly_train._data.task_data_args import TaskDataArgs
from lightly_train._env import Env
from lightly_train._transforms.task_transform import TaskTransform
from lightly_train._transforms.semantic_segmentation_transform import (
SemanticSegmentationTransform,
SemanticSegmentationTransformArgs,
)
from lightly_train.types import (
BinaryMasksDict,
ImageFilename,
Expand All @@ -40,7 +43,7 @@ def __init__(
self,
dataset_args: MaskSemanticSegmentationDatasetArgs,
image_filenames: Sequence[ImageFilename],
transform: TaskTransform,
transform: SemanticSegmentationTransform,
):
self.args = dataset_args
self.image_filenames = image_filenames
Expand All @@ -51,20 +54,27 @@ def __init__(
self.class_mapping = self.get_class_mapping()
self.valid_classes = torch.tensor(list(self.class_mapping.keys()))

image_mode = Env.LIGHTLY_TRAIN_IMAGE_MODE.value
if image_mode not in ("RGB", "UNCHANGED"):
transform_args = transform.transform_args
assert isinstance(transform_args, SemanticSegmentationTransformArgs)

image_mode = (
None
if Env.LIGHTLY_TRAIN_IMAGE_MODE.value is None
else ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value)
)
if image_mode is None:
image_mode = (
ImageMode.RGB
if transform_args.num_channels == 3
else ImageMode.UNCHANGED
)

if image_mode not in (ImageMode.RGB, ImageMode.UNCHANGED):
raise ValueError(
f'Invalid image mode: {Env.LIGHTLY_TRAIN_IMAGE_MODE.name}="{image_mode}". '
"Supported modes are 'RGB' and 'UNCHANGED'."
f"Invalid image mode: '{image_mode}'. "
f"Supported modes are '{[ImageMode.RGB.value, ImageMode.UNCHANGED.value]}'."
)
# Convert string to enum value
if image_mode == "RGB":
self.image_mode = ImageMode.RGB
elif image_mode == "UNCHANGED":
self.image_mode = ImageMode.UNCHANGED
else:
# This should not happen due to the check above, but added for type safety
raise ValueError(f"Unexpected image mode: {image_mode}")
self.image_mode = image_mode

def is_mask_valid(self, mask: Tensor) -> bool:
# Check if at least one value in the mask is in the valid classes.
Expand Down
4 changes: 2 additions & 2 deletions src/lightly_train/_embedding/embedding_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class EmbeddingTransform:
def __init__(
self,
image_size: int | tuple[int, int],
mean: tuple[float, float, float],
std: tuple[float, float, float],
mean: tuple[float, ...],
std: tuple[float, ...],
):
if isinstance(image_size, int):
image_size = (image_size, image_size)
Expand Down
4 changes: 2 additions & 2 deletions src/lightly_train/_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ class Env:
)
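# Mode in which images are loaded. This can be "RGB" to load images in RGB or
# "UNCHANGED" to load images in their original format without any conversion.
# If unset, the datasets derive the mode from the number of channels: "RGB" for
# 3 channels, "UNCHANGED" otherwise.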
LIGHTLY_TRAIN_IMAGE_MODE: EnvVar[str] = EnvVar(
LIGHTLY_TRAIN_IMAGE_MODE: EnvVar[str | None] = EnvVar(
name="LIGHTLY_TRAIN_IMAGE_MODE",
_default="RGB",
_default=None,
_type=str,
)
LIGHTLY_TRAIN_MASK_DIR: EnvVar[Path | None] = EnvVar(
Expand Down
2 changes: 2 additions & 0 deletions src/lightly_train/_methods/densecl/densecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
):
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args
self.query_encoder = DenseCLEncoder(
Expand Down
2 changes: 2 additions & 0 deletions src/lightly_train/_methods/densecl/densecl_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

from pydantic import Field
from typing_extensions import Literal

from lightly_train._transforms.transform import (
ChannelDropArgs,
Expand Down Expand Up @@ -55,6 +56,7 @@ class DenseCLGaussianBlurArgs(GaussianBlurArgs):
class DenseCLTransformArgs(MethodTransformArgs):
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: DenseCLRandomResizeArgs | None = Field(
default_factory=DenseCLRandomResizeArgs
)
Expand Down
4 changes: 4 additions & 0 deletions src/lightly_train/_methods/detcon/detcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
) -> None:
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args
self.embedding_model = embedding_model
Expand Down Expand Up @@ -268,12 +270,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
) -> None:
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args

Expand Down
4 changes: 4 additions & 0 deletions src/lightly_train/_methods/detcon/detcon_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#
from __future__ import annotations

from typing import Literal

from pydantic import Field

from lightly_train._configs.config import PydanticConfig
Expand Down Expand Up @@ -58,6 +60,7 @@ class DetConSView1TransformArgs(PydanticConfig):
class DetConSTransformArgs(MethodTransformArgs):
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: RandomResizeArgs | None = Field(default_factory=RandomResizeArgs)
random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs)
random_rotation: RandomRotationArgs | None = None
Expand Down Expand Up @@ -109,6 +112,7 @@ class DetConBView1TransformArgs(PydanticConfig):
class DetConBTransformArgs(MethodTransformArgs):
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: RandomResizeArgs | None = Field(default_factory=RandomResizeArgs)
random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs)
random_rotation: RandomRotationArgs | None = None
Expand Down
2 changes: 2 additions & 0 deletions src/lightly_train/_methods/dino/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
):
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args
self.teacher_embedding_model = embedding_model
Expand Down
3 changes: 3 additions & 0 deletions src/lightly_train/_methods/dino/dino_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#
from __future__ import annotations

from typing import Literal

from pydantic import Field

from lightly_train._configs.config import PydanticConfig
Expand Down Expand Up @@ -99,6 +101,7 @@ class DINOTransformArgs(MethodTransformArgs):
# https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: DINORandomResizeArgs | None = Field(
default_factory=DINORandomResizeArgs
)
Expand Down
Loading