Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightly_train/_commands/common_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def get_dataset_mmap_filenames(
def get_dataset(
data: PathLike | Sequence[PathLike] | Dataset[DatasetItem],
transform: Transform,
num_channels: int,
mmap_filepath: Path | None,
out_dir: Path,
) -> Dataset[DatasetItem]:
Expand Down Expand Up @@ -583,6 +584,7 @@ def get_dataset(
mmap_filepath=mmap_filepath,
),
transform=transform,
num_channels=num_channels,
mask_dir=Path(mask_dir) if mask_dir is not None else None,
)

Expand All @@ -609,6 +611,7 @@ def get_dataset(
mmap_filepath=mmap_filepath,
),
transform=transform,
num_channels=num_channels,
)
else:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions src/lightly_train/_commands/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def embed_from_config(config: EmbedConfig) -> None:
dataset = common_helpers.get_dataset(
data=config.data,
transform=transform,
num_channels=len(checkpoint_instance.lightly_train.normalize_args.mean),
mmap_filepath=mmap_filepath,
out_dir=out_path,
)
Expand Down
7 changes: 6 additions & 1 deletion src/lightly_train/_commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lightly_train._commands.common_helpers import ModelFormat
from lightly_train._configs import omegaconf_utils, validate
from lightly_train._configs.config import PydanticConfig
from lightly_train._configs.validate import no_auto
from lightly_train._loggers import logger_helpers
from lightly_train._loggers.logger_args import LoggerArgs
from lightly_train._methods import method_helpers
Expand Down Expand Up @@ -295,6 +296,7 @@ def train_from_config(config: TrainConfig) -> None:
dataset = common_helpers.get_dataset(
data=config.data,
transform=transform_instance,
num_channels=no_auto(transform_instance.transform_args.num_channels),
mmap_filepath=mmap_filepath,
out_dir=out_dir,
)
Expand All @@ -310,7 +312,9 @@ def train_from_config(config: TrainConfig) -> None:
epochs=config.epochs,
)
wrapped_model = package_helpers.get_wrapped_model(
model=config.model, model_args=config.model_args
model=config.model,
model_args=config.model_args,
num_input_channels=no_auto(transform_instance.transform_args.num_channels),
)
embedding_model = train_helpers.get_embedding_model(
wrapped_model=wrapped_model, embed_dim=config.embed_dim
Expand Down Expand Up @@ -398,6 +402,7 @@ def train_from_config(config: TrainConfig) -> None:
optimizer_args=config.optim_args,
embedding_model=embedding_model,
global_batch_size=config.batch_size,
num_input_channels=no_auto(transform_instance.transform_args.num_channels),
)
train_helpers.load_checkpoint(
checkpoint=config.checkpoint,
Expand Down
13 changes: 11 additions & 2 deletions src/lightly_train/_commands/train_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_transform_args(
logger.debug(f"Getting transform args for method '{method}'.")
logger.debug(f"Using additional transform arguments {transform_args}.")
if isinstance(transform_args, MethodTransformArgs):
transform_args.resolve_auto()
return transform_args

method_cls = method_helpers.get_method_cls(method)
Expand All @@ -60,9 +61,15 @@ def get_transform_args(
if transform_args is None:
# We need to typeignore here because a MethodTransformArgs might not have
# defaults for all fields, while its children do.
return transform_args_cls() # type: ignore[call-arg]
transform_args = transform_args_cls() # type: ignore[call-arg]
else:
transform_args = validate.pydantic_model_validate(
transform_args_cls, transform_args
)

return validate.pydantic_model_validate(transform_args_cls, transform_args)
transform_args.resolve_auto()
transform_args.resolve_incompatible()
return transform_args


def get_transform(
Expand Down Expand Up @@ -326,13 +333,15 @@ def get_method(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
) -> Method:
logger.debug(f"Getting method for '{method_cls.__name__}'")
return method_cls(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)


Expand Down
14 changes: 13 additions & 1 deletion src/lightly_train/_commands/train_task_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
TrainModelArgs,
)
from lightly_train._train_task_state import TrainTaskState
from lightly_train._transforms.semantic_segmentation_transform import (
SemanticSegmentationTransform,
)
from lightly_train._transforms.task_transform import (
TaskTransform,
TaskTransformArgs,
Expand Down Expand Up @@ -229,17 +232,24 @@ def get_transform_args(
train_transform_args_cls, transform_args
)
train_transform_args.resolve_auto()
train_transform_args.resolve_incompatible()

# Take defaults from train transform.
val_args_dict = train_transform_args.model_dump(
include={"image_size": True, "normalize": True, "ignore_index": True}
include={
"image_size": True,
"normalize": True,
"ignore_index": True,
"num_channels": True,
}
)
# Overwrite with user provided val args.
val_args_dict.update(val_args)
val_transform_args = validate.pydantic_model_validate(
val_transform_args_cls, val_args_dict
)
val_transform_args.resolve_auto()
val_transform_args.resolve_incompatible()

logger.debug(
f"Resolved train transform args {pretty_format_args(train_transform_args.model_dump())}"
Expand Down Expand Up @@ -408,6 +418,8 @@ def get_dataset(
image_info = dataset_args.list_image_info()

dataset_cls = dataset_args.get_dataset_cls()
# TODO(Guarin, 08/25): Relax this when we add object detection.
assert isinstance(transform, SemanticSegmentationTransform)
return dataset_cls(
dataset_args=dataset_args,
image_info=get_dataset_mmap_file(
Expand Down
21 changes: 16 additions & 5 deletions src/lightly_train/_data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,31 @@ def __init__(
image_dir: Path | None,
image_filenames: Sequence[ImageFilename],
transform: Transform,
num_channels: int,
mask_dir: Path | None = None,
):
self.image_dir = image_dir
self.image_filenames = image_filenames
self.mask_dir = mask_dir
self.transform = transform
self.num_channels = num_channels

try:
self.image_mode = ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value)
except ValueError:
image_mode = (
None
if Env.LIGHTLY_TRAIN_IMAGE_MODE.value is None
else ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value)
)
if image_mode is None:
image_mode = (
ImageMode.RGB if self.num_channels == 3 else ImageMode.UNCHANGED
)

if image_mode not in (ImageMode.RGB, ImageMode.UNCHANGED):
raise ValueError(
f'Invalid image mode: {Env.LIGHTLY_TRAIN_IMAGE_MODE.name}="{Env.LIGHTLY_TRAIN_IMAGE_MODE.value}". '
"Supported modes are 'RGB' and 'UNCHANGED'."
f"Invalid image mode: '{image_mode}'. "
f"Supported modes are '{[ImageMode.RGB.value, ImageMode.UNCHANGED.value]}'."
)
self.image_mode = image_mode

def __getitem__(self, idx: int) -> DatasetItem:
filename = self.image_filenames[idx]
Expand Down
38 changes: 24 additions & 14 deletions src/lightly_train/_data/mask_semantic_segmentation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from lightly_train._data.file_helpers import ImageMode
from lightly_train._data.task_data_args import TaskDataArgs
from lightly_train._env import Env
from lightly_train._transforms.task_transform import TaskTransform
from lightly_train._transforms.semantic_segmentation_transform import (
SemanticSegmentationTransform,
SemanticSegmentationTransformArgs,
)
from lightly_train.types import (
BinaryMasksDict,
MaskSemanticSegmentationDatasetItem,
Expand All @@ -39,7 +42,7 @@ def __init__(
self,
dataset_args: MaskSemanticSegmentationDatasetArgs,
image_info: Sequence[dict[str, str]],
transform: TaskTransform,
transform: SemanticSegmentationTransform,
):
self.args = dataset_args
self.filepaths = image_info
Expand All @@ -50,20 +53,27 @@ def __init__(
self.class_mapping = self.get_class_mapping()
self.valid_classes = torch.tensor(list(self.class_mapping.keys()))

image_mode = Env.LIGHTLY_TRAIN_IMAGE_MODE.value
if image_mode not in ("RGB", "UNCHANGED"):
transform_args = transform.transform_args
assert isinstance(transform_args, SemanticSegmentationTransformArgs)

image_mode = (
None
if Env.LIGHTLY_TRAIN_IMAGE_MODE.value is None
else ImageMode(Env.LIGHTLY_TRAIN_IMAGE_MODE.value)
)
if image_mode is None:
image_mode = (
ImageMode.RGB
if transform_args.num_channels == 3
else ImageMode.UNCHANGED
)

if image_mode not in (ImageMode.RGB, ImageMode.UNCHANGED):
raise ValueError(
f'Invalid image mode: {Env.LIGHTLY_TRAIN_IMAGE_MODE.name}="{image_mode}". '
"Supported modes are 'RGB' and 'UNCHANGED'."
f"Invalid image mode: '{image_mode}'. "
f"Supported modes are '{[ImageMode.RGB.value, ImageMode.UNCHANGED.value]}'."
)
# Convert string to enum value
if image_mode == "RGB":
self.image_mode = ImageMode.RGB
elif image_mode == "UNCHANGED":
self.image_mode = ImageMode.UNCHANGED
else:
# This should not happen due to the check above, but added for type safety
raise ValueError(f"Unexpected image mode: {image_mode}")
self.image_mode = image_mode

def is_mask_valid(self, mask: Tensor) -> bool:
# Check if at least one value in the mask is in the valid classes.
Expand Down
4 changes: 2 additions & 2 deletions src/lightly_train/_embedding/embedding_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class EmbeddingTransform:
def __init__(
self,
image_size: int | tuple[int, int],
mean: tuple[float, float, float],
std: tuple[float, float, float],
mean: tuple[float, ...],
std: tuple[float, ...],
):
if isinstance(image_size, int):
image_size = (image_size, image_size)
Expand Down
4 changes: 2 additions & 2 deletions src/lightly_train/_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ class Env:
)
# Mode in which images are loaded. This can be "RGB" to load images in RGB or
# "UNCHANGED" to load images in their original format without any conversion.
LIGHTLY_TRAIN_IMAGE_MODE: EnvVar[str] = EnvVar(
LIGHTLY_TRAIN_IMAGE_MODE: EnvVar[str | None] = EnvVar(
name="LIGHTLY_TRAIN_IMAGE_MODE",
_default="RGB",
_default=None,
_type=str,
)
LIGHTLY_TRAIN_MASK_DIR: EnvVar[Path | None] = EnvVar(
Expand Down
2 changes: 2 additions & 0 deletions src/lightly_train/_methods/densecl/densecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
):
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args
self.query_encoder = DenseCLEncoder(
Expand Down
2 changes: 2 additions & 0 deletions src/lightly_train/_methods/densecl/densecl_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

from pydantic import Field
from typing_extensions import Literal

from lightly_train._transforms.transform import (
ChannelDropArgs,
Expand Down Expand Up @@ -55,6 +56,7 @@ class DenseCLGaussianBlurArgs(GaussianBlurArgs):
class DenseCLTransformArgs(MethodTransformArgs):
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: DenseCLRandomResizeArgs | None = Field(
default_factory=DenseCLRandomResizeArgs
)
Expand Down
4 changes: 4 additions & 0 deletions src/lightly_train/_methods/detcon/detcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
) -> None:
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args
self.embedding_model = embedding_model
Expand Down Expand Up @@ -268,12 +270,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
) -> None:
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args

Expand Down
4 changes: 4 additions & 0 deletions src/lightly_train/_methods/detcon/detcon_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#
from __future__ import annotations

from typing import Literal

from pydantic import Field

from lightly_train._configs.config import PydanticConfig
Expand Down Expand Up @@ -58,6 +60,7 @@ class DetConSView1TransformArgs(PydanticConfig):
class DetConSTransformArgs(MethodTransformArgs):
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: RandomResizeArgs | None = Field(default_factory=RandomResizeArgs)
random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs)
random_rotation: RandomRotationArgs | None = None
Expand Down Expand Up @@ -109,6 +112,7 @@ class DetConBView1TransformArgs(PydanticConfig):
class DetConBTransformArgs(MethodTransformArgs):
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: RandomResizeArgs | None = Field(default_factory=RandomResizeArgs)
random_flip: RandomFlipArgs | None = Field(default_factory=RandomFlipArgs)
random_rotation: RandomRotationArgs | None = None
Expand Down
2 changes: 2 additions & 0 deletions src/lightly_train/_methods/dino/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,14 @@ def __init__(
optimizer_args: OptimizerArgs,
embedding_model: EmbeddingModel,
global_batch_size: int,
num_input_channels: int,
):
super().__init__(
method_args=method_args,
optimizer_args=optimizer_args,
embedding_model=embedding_model,
global_batch_size=global_batch_size,
num_input_channels=num_input_channels,
)
self.method_args = method_args
self.teacher_embedding_model = embedding_model
Expand Down
3 changes: 3 additions & 0 deletions src/lightly_train/_methods/dino/dino_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#
from __future__ import annotations

from typing import Literal

from pydantic import Field

from lightly_train._configs.config import PydanticConfig
Expand Down Expand Up @@ -99,6 +101,7 @@ class DINOTransformArgs(MethodTransformArgs):
# https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
image_size: tuple[int, int] = Field(default=(224, 224), strict=False)
channel_drop: ChannelDropArgs | None = None
num_channels: int | Literal["auto"] = "auto"
random_resize: DINORandomResizeArgs | None = Field(
default_factory=DINORandomResizeArgs
)
Expand Down
Loading
Loading