Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 41 additions & 25 deletions src/rfdetr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,36 @@
if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from rfdetr.detr import (
RFDETRBase,
RFDETRLarge,
RFDETRLargeDeprecated,
RFDETRMedium,
RFDETRNano,
RFDETRSeg2XLarge,
RFDETRSegLarge,
RFDETRSegMedium,
RFDETRSegNano,
RFDETRSegPreview,
RFDETRSegSmall,
RFDETRSegXLarge,
RFDETRSmall,
)
# Model classes resolved lazily via __getattr__ to avoid eagerly importing
# the entire training/inference stack (torch, transformers, peft, scipy, etc.)
_DETR_EXPORTS = {
"RFDETRBase",
"RFDETRLarge",
"RFDETRLargeDeprecated",
"RFDETRMedium",
"RFDETRNano",
"RFDETRSeg2XLarge",
"RFDETRSegLarge",
"RFDETRSegMedium",
"RFDETRSegNano",
"RFDETRSegPreview",
"RFDETRSegSmall",
"RFDETRSegXLarge",
"RFDETRSmall",
}

_PLUS_EXPORTS = {
"RFDETR2XLarge",
"RFDETRXLarge",
}

__all__ = [
"RFDETRNano",
"RFDETRSmall",
"RFDETRSmall",
"RFDETRBase",
"RFDETRMedium",
"RFDETRLarge",
"RFDETRLargeDeprecated",
"RFDETRSegNano",
"RFDETRSegSmall",
"RFDETRSegMedium",
Expand All @@ -38,27 +47,34 @@
"RFDETRSeg2XLarge",
]


def __getattr__(name: str):
"""Resolve plus-only exports lazily, raising only on explicit access."""
_PLUS_EXPORTS = {
"RFDETR2XLarge",
"RFDETRXLarge"
}
"""Resolve model class exports lazily to keep ``import rfdetr`` fast."""
if name in _DETR_EXPORTS:
from rfdetr import detr as _detr_module

# Cache all detr exports at once so __getattr__ is only called once.
for export_name in _DETR_EXPORTS:
globals()[export_name] = getattr(_detr_module, export_name)
return globals()[name]

if name in _PLUS_EXPORTS:
from rfdetr.platform import _INSTALL_MSG
from rfdetr.platform import models as _platform_models

# Cache the resolved symbol to avoid repeated attribute lookups.
if hasattr(_platform_models, name):
value = getattr(_platform_models, name)
globals()[name] = value
# Keep __all__ in sync with dynamically resolved exports.
if name not in __all__:
__all__.append(name)
return value

# The name is expected to be plus-only; raise a clear install hint.
raise ImportError(_INSTALL_MSG.format(name="platform model downloads"))

# Non-plus names fall back to the default attribute error.
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
"""Include lazily-resolved model classes in dir() output."""
module_attrs = list(globals().keys())
return module_attrs + list(_DETR_EXPORTS) + list(_PLUS_EXPORTS)
Loading