Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/actions/pkg-create/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ runs:
using: "composite"
steps:
- name: Create package 📦
# python setup.py clean
run: python -m build --verbose
run: uv build --verbose
shell: bash

- name: Check package 📦
Expand Down
4 changes: 1 addition & 3 deletions .github/workflows/check-precommit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ jobs:

- name: Run pre-commit 🤖
id: precommit
run: |
uv pip install -q pre-commit
pre-commit run --all-files
run: uvx pre-commit run --all-files

- name: Minimize uv cache
run: uv cache prune --ci
Expand Down
27 changes: 27 additions & 0 deletions .github/workflows/check-ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Check Ruff linting

on:
workflow_call:
inputs:
ruff-version:
description: "Ruff version to use, defaults to latest"
default: ""
required: false
type: string

defaults:
run:
shell: bash

jobs:
ruff:
runs-on: ubuntu-latest
steps:
- name: Checkout 🛎️
uses: actions/checkout@v6

- name: Run Ruff 🐍
uses: astral-sh/ruff-action@v3
with:
version: ${{ inputs.ruff-version }}
args: "check --output-format=github"
3 changes: 3 additions & 0 deletions .github/workflows/ci-use-checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ jobs:
actions-ref: ${{ github.sha }} # use local version
extra-typing: "typing"

check-ruff:
uses: ./.github/workflows/check-ruff.yml

check-precommit:
uses: ./.github/workflows/check-precommit.yml

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ requires = [
]

[tool.ruff]
target-version = "py39"
target-version = "py310"
line-length = 120
format.preview = true
lint.select = [
Expand Down
5 changes: 2 additions & 3 deletions scripts/adjust-torch-versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import os
import re
import sys
from typing import Optional


def _determine_torchaudio(torch_version: str) -> str:
Expand Down Expand Up @@ -136,7 +135,7 @@ def find_latest(ver: str) -> dict[str, str]:
}


def adjust(requires: list[str], pytorch_version: Optional[str] = None) -> list[str]:
def adjust(requires: list[str], pytorch_version: str | None = None) -> list[str]:
"""Adjust the versions to be paired within pytorch ecosystem.

>>> from pprint import pprint
Expand Down Expand Up @@ -184,7 +183,7 @@ def _offset_print(reqs: list[str], offset: str = "\t|\t") -> str:
return os.linesep.join(reqs)


def main(requirements_path: str, torch_version: Optional[str] = None) -> None:
def main(requirements_path: str, torch_version: str | None = None) -> None:
"""The main entry point with mapping to the CLI for positional arguments only."""
# rU - universal line ending - https://stackoverflow.com/a/2717154/4521646
with open(requirements_path, encoding="utf8") as fopen:
Expand Down
7 changes: 3 additions & 4 deletions src/lightning_utilities/cli/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import warnings
from collections.abc import Sequence
from pprint import pprint
from typing import Union

REQUIREMENT_ROOT = "requirements.txt"
REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt"))
Expand All @@ -19,7 +18,7 @@


def prune_packages_in_requirements(
packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
packages: str | Sequence[str], req_files: str | Sequence[str] = REQUIREMENT_FILES_ALL
) -> None:
"""Remove one or more packages from the specified requirement files.

Expand Down Expand Up @@ -101,7 +100,7 @@ def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> Non
f.write(tomlkit.dumps(doc))


def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
def replace_oldest_version(req_files: str | Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
"""Convert minimal version specifiers (>=) to pinned ones (==) in the given requirement files.

Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning.
Expand Down Expand Up @@ -178,7 +177,7 @@ def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, ne


def replace_package_in_requirements(
old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
old_package: str, new_package: str, req_files: str | Sequence[str] = REQUIREMENT_FILES_ALL
) -> None:
"""Rename a package across multiple requirement files while keeping version constraints intact.

Expand Down
20 changes: 10 additions & 10 deletions src/lightning_utilities/core/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#
import dataclasses
from collections import OrderedDict, defaultdict
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from copy import deepcopy
from functools import cached_property
from typing import Any, Callable, Optional, Union
from typing import Any


def is_namedtuple(obj: object) -> bool:
Expand All @@ -28,10 +28,10 @@ def is_dataclass_instance(obj: object) -> bool:

def apply_to_collection(
data: Any,
dtype: Union[type, Any, tuple[Union[type, Any]]],
dtype: type | Any | tuple[type | Any],
function: Callable,
*args: Any,
wrong_dtype: Optional[Union[type, tuple[type, ...]]] = None,
wrong_dtype: type | tuple[type, ...] | None = None,
include_none: bool = True,
allow_frozen: bool = False,
**kwargs: Any,
Expand Down Expand Up @@ -89,10 +89,10 @@ def apply_to_collection(

def _apply_to_collection_slow(
data: Any,
dtype: Union[type, Any, tuple[Union[type, Any]]],
dtype: type | Any | tuple[type | Any],
function: Callable,
*args: Any,
wrong_dtype: Optional[Union[type, tuple[type, ...]]] = None,
wrong_dtype: type | tuple[type, ...] | None = None,
include_none: bool = True,
allow_frozen: bool = False,
**kwargs: Any,
Expand Down Expand Up @@ -191,12 +191,12 @@ def _apply_to_collection_slow(


def apply_to_collections(
data1: Optional[Any],
data2: Optional[Any],
dtype: Union[type, Any, tuple[Union[type, Any]]],
data1: Any | None,
data2: Any | None,
dtype: type | Any | tuple[type | Any],
function: Callable,
*args: Any,
wrong_dtype: Optional[Union[type, tuple[type]]] = None,
wrong_dtype: type | tuple[type] | None = None,
**kwargs: Any,
) -> Any:
"""Zip two collections and apply a function to items of a certain dtype.
Expand Down
4 changes: 1 addition & 3 deletions src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
#
import warnings
from enum import Enum
from typing import Optional

from typing_extensions import Literal
from typing import Literal, Optional


class StrEnum(str, Enum):
Expand Down
9 changes: 5 additions & 4 deletions src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import importlib
import os
import warnings
from collections.abc import Callable
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, distribution
from importlib.metadata import version as _version
from importlib.util import find_spec
from types import ModuleType
from typing import Any, Callable, Optional, TypeVar
from typing import Any, TypeVar

from packaging.requirements import Requirement
from packaging.version import InvalidVersion, Version
Expand Down Expand Up @@ -117,7 +118,7 @@ class RequirementCache:

"""

def __init__(self, requirement: Optional[str] = None, module: Optional[str] = None) -> None:
def __init__(self, requirement: str | None = None, module: str | None = None) -> None:
if not (requirement or module):
raise ValueError("At least one arguments need to be set.")
self.requirement = requirement
Expand Down Expand Up @@ -262,7 +263,7 @@ class LazyModule(ModuleType):

"""

def __init__(self, module_name: str, callback: Optional[Callable] = None) -> None:
def __init__(self, module_name: str, callback: Callable | None = None) -> None:
super().__init__(module_name)
self._module: Any = None
self._callback = callback
Expand Down Expand Up @@ -294,7 +295,7 @@ def _import_module(self) -> None:
self.__dict__.update(self._module.__dict__)


def lazy_import(module_name: str, callback: Optional[Callable] = None) -> LazyModule:
def lazy_import(module_name: str, callback: Callable | None = None) -> LazyModule:
"""Return a proxy module object that will lazily import the given module the first time it is used.

Example usage:
Expand Down
17 changes: 9 additions & 8 deletions src/lightning_utilities/core/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import logging
import warnings
from collections.abc import Callable
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, TypeVar

from typing_extensions import ParamSpec, overload

Expand All @@ -18,14 +19,14 @@


@overload
def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: ...
def rank_zero_only(fn: Callable[P, T]) -> Callable[P, T | None]: ...


@overload
def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ...


def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]:
def rank_zero_only(fn: Callable[P, T], default: T | None = None) -> Callable[P, T | None]:
"""Decorator to run the wrapped function only on global rank 0.

Set ``rank_zero_only.rank`` before use. On non-zero ranks, the function is skipped and the provided
Expand All @@ -34,7 +35,7 @@ def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[
"""

@wraps(fn)
def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> T | None:
rank = getattr(rank_zero_only, "rank", None)
if rank is None:
raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
Expand Down Expand Up @@ -67,26 +68,26 @@ def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
_info(*args, stacklevel=stacklevel, **kwargs)


def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
def _warn(message: str | Warning, stacklevel: int = 2, **kwargs: Any) -> None:
warnings.warn(message, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
def rank_zero_warn(message: str | Warning, stacklevel: int = 4, **kwargs: Any) -> None:
"""Emit warn-level messages only on global rank 0."""
_warn(message, stacklevel=stacklevel, **kwargs)


rank_zero_deprecation_category = DeprecationWarning


def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **kwargs: Any) -> None:
def rank_zero_deprecation(message: str | Warning, stacklevel: int = 5, **kwargs: Any) -> None:
"""Emit a deprecation warning only on global rank 0."""
category = kwargs.pop("category", rank_zero_deprecation_category)
rank_zero_warn(message, stacklevel=stacklevel, category=category, **kwargs)


def rank_prefixed_message(message: str, rank: Optional[int]) -> str:
def rank_prefixed_message(message: str, rank: int | None) -> str:
"""Add a ``[rank: X]`` prefix to the message if ``rank`` is provided; otherwise return the message unchanged."""
if rank is not None:
# specify the rank of the process being logged
Expand Down
5 changes: 2 additions & 3 deletions src/lightning_utilities/docs/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import re
import sys
from collections.abc import Iterable
from typing import Optional, Union


def _transform_changelog(path_in: str, path_out: str) -> None:
Expand Down Expand Up @@ -97,7 +96,7 @@ def _load_pypi_versions(package_name: str) -> list[str]:
return sorted(versions, key=Version)


def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str:
def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: int | None) -> str:
"""Resolve a ``{package.version}`` placeholder in a link using the latest available version.

Args:
Expand Down Expand Up @@ -128,7 +127,7 @@ def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits:
def adjust_linked_external_docs(
source_link: str,
target_link: str,
browse_folder: Union[str, Iterable[str]],
browse_folder: str | Iterable[str],
file_extensions: Iterable[str] = (".rst", ".py"),
version_digits: int = 2,
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/lightning_utilities/install/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import re
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any

from packaging.requirements import Requirement
from packaging.version import Version


def _yield_lines(strs: Union[str, Iterable[str]]) -> Iterator[str]:
def _yield_lines(strs: str | Iterable[str]) -> Iterator[str]:
"""Yield non-empty, non-comment lines from a string or iterable of strings.

Adapted from pkg_resources.yield_lines.
Expand Down Expand Up @@ -45,7 +45,7 @@ class _RequirementWithComment(Requirement):

strict_string = "# strict"

def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
def __init__(self, *args: Any, comment: str = "", pip_argument: str | None = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.comment = comment
if not (pip_argument is None or pip_argument): # sanity check that it's not an empty str
Expand Down Expand Up @@ -110,7 +110,7 @@ def adjust(self, unfreeze: str) -> str:
return out


def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
def _parse_requirements(strs: str | Iterable[str]) -> Iterator[_RequirementWithComment]:
r"""Adapted from ``pkg_resources.parse_requirements`` to include comments and pip arguments.

Parses a sequence or string of requirement lines, preserving trailing comments and associating any
Expand Down
3 changes: 1 addition & 2 deletions src/lightning_utilities/test/warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import warnings
from collections.abc import Generator
from contextlib import contextmanager
from typing import Optional


@contextmanager
def no_warning_call(expected_warning: type[Warning] = Warning, match: Optional[str] = None) -> Generator:
def no_warning_call(expected_warning: type[Warning] = Warning, match: str | None = None) -> Generator:
"""Assert that no matching warning is emitted within the context.

Args:
Expand Down
Loading
Loading