Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions libcudacxx/include/cuda/__memory_resource/any_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,100 @@ synchronous_resource_ref<_Properties...> __as_resource_ref(resource_ref<_Propert
return __mr;
}

_CCCL_BEGIN_NAMESPACE_ABI_VER4_BUMP

//! @brief Attempts to cast the type-erased resource stored in \p __res to \c _Resource.
//!
//! @tparam _Resource The resource type to cast to.
//! @tparam _Properties The properties of the \c any_resource object.
//! @param __res Pointer to the \c any_resource object to cast.
//! @return A pointer to the contained \c _Resource object if the stored type is
//! exactly \c _Resource, or \c nullptr otherwise.
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API _Resource* any_cast(any_resource<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(static_cast<::cuda::__basic_any<__iasync_resource<_Properties...>>*>(__res));
}

//! @overload
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API const _Resource* any_cast(const any_resource<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(
static_cast<const ::cuda::__basic_any<__iasync_resource<_Properties...>>*>(__res));
}

//! @brief Attempts to cast the type-erased resource stored in \p __res to \c _Resource.
//!
//! @tparam _Resource The resource type to cast to.
//! @tparam _Properties The properties of the \c any_synchronous_resource object.
//! @param __res Pointer to the \c any_synchronous_resource object to cast.
//! @return A pointer to the contained \c _Resource object if the stored type is
//! exactly \c _Resource, or \c nullptr otherwise.
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::synchronous_resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API _Resource* any_cast(any_synchronous_resource<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(static_cast<::cuda::__basic_any<__iresource<_Properties...>>*>(__res));
}

//! @overload
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::synchronous_resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API const _Resource* any_cast(const any_synchronous_resource<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(static_cast<const ::cuda::__basic_any<__iresource<_Properties...>>*>(__res));
}

//! @brief Attempts to cast the type-erased resource referenced by \p __res to \c _Resource.
//!
//! @tparam _Resource The resource type to cast to.
//! @tparam _Properties The properties of the \c resource_ref object.
//! @param __res Pointer to the \c resource_ref object to cast.
//! @return A pointer to the referenced \c _Resource object if the referenced type is
//! exactly \c _Resource, or \c nullptr otherwise.
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API _Resource* any_cast(resource_ref<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(static_cast<::cuda::__basic_any<__iasync_resource<_Properties...>&>*>(__res));
}

//! @overload
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API const _Resource* any_cast(const resource_ref<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(
static_cast<const ::cuda::__basic_any<__iasync_resource<_Properties...>&>*>(__res));
}

//! @brief Attempts to cast the type-erased resource referenced by \p __res to \c _Resource.
//!
//! @tparam _Resource The resource type to cast to.
//! @tparam _Properties The properties of the \c synchronous_resource_ref object.
//! @param __res Pointer to the \c synchronous_resource_ref object to cast.
//! @return A pointer to the referenced \c _Resource object if the referenced type is
//! exactly \c _Resource, or \c nullptr otherwise.
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::synchronous_resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API _Resource* any_cast(synchronous_resource_ref<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(static_cast<::cuda::__basic_any<__iresource<_Properties...>&>*>(__res));
}

//! @overload
_CCCL_TEMPLATE(class _Resource, class... _Properties)
_CCCL_REQUIRES(mr::synchronous_resource_with<_Resource, _Properties...>)
[[nodiscard]] _CCCL_HOST_API const _Resource* any_cast(const synchronous_resource_ref<_Properties...>* __res) noexcept
{
return ::cuda::__any_cast<_Resource>(static_cast<const ::cuda::__basic_any<__iresource<_Properties...>&>*>(__res));
}

_CCCL_END_NAMESPACE_ABI_VER4_BUMP

# else // ^^^ !_CCCL_DOXYGEN_INVOKED ^^^ / vvv _CCCL_DOXYGEN_INVOKED vvv

enum class _ResourceKind
Expand Down Expand Up @@ -893,6 +987,66 @@ using synchronous_resource_ref = basic_resource_ref<_ResourceKind::_Synchronous,
template <class... _Properties>
using resource_ref = basic_resource_ref<_ResourceKind::_Asynchronous, _Properties...>;

//! @rst
//! .. _libcudacxx-memory-resource-any-cast:
//!
//! ``any_cast`` for type-erased memory resources
//! -----------------------------------------------
//!
//! ``any_cast`` attempts to obtain a pointer to the concrete resource type stored
//! in a type-erased memory resource wrapper. It is analogous to
//! ``std::any_cast`` for ``std::any``.
//!
//! When called with a pointer to an ``any_resource``, ``any_synchronous_resource``,
//! ``resource_ref``, or ``synchronous_resource_ref``, ``any_cast`` returns a pointer
//! to the contained or referenced object if its type matches ``_Resource`` exactly,
//! or ``nullptr`` otherwise.
//!
//! @endrst
//!
//! @tparam _Resource The concrete resource type to cast to.
//! @tparam _Properties The properties of the type-erased resource wrapper.
//! @param __res Pointer to the type-erased resource wrapper.
//! @return A pointer to the contained or referenced \c _Resource if the stored type
//! is exactly \c _Resource, or \c nullptr otherwise. Also returns \c nullptr if
//! \p __res is \c nullptr or empty.

//! @brief Attempts to obtain a pointer to the resource of type \c _Resource
//! stored in an \c any_resource.
template <class _Resource, class... _Properties>
_Resource* any_cast(any_resource<_Properties...>* __res) noexcept;

//! @overload
template <class _Resource, class... _Properties>
const _Resource* any_cast(const any_resource<_Properties...>* __res) noexcept;

//! @brief Attempts to obtain a pointer to the resource of type \c _Resource
//! stored in an \c any_synchronous_resource.
template <class _Resource, class... _Properties>
_Resource* any_cast(any_synchronous_resource<_Properties...>* __res) noexcept;

//! @overload
template <class _Resource, class... _Properties>
const _Resource* any_cast(const any_synchronous_resource<_Properties...>* __res) noexcept;

//! @brief Attempts to obtain a pointer to the resource of type \c _Resource
//! referenced by a \c resource_ref.
template <class _Resource, class... _Properties>
_Resource* any_cast(resource_ref<_Properties...>* __ref) noexcept;

//! @overload
template <class _Resource, class... _Properties>
const _Resource* any_cast(const resource_ref<_Properties...>* __ref) noexcept;

//! @brief Attempts to obtain a pointer to the resource of type \c _Resource
//! referenced by a \c synchronous_resource_ref.
template <class _Resource, class... _Properties>
_Resource* any_cast(synchronous_resource_ref<_Properties...>* __ref) noexcept;

//! @overload
template <class _Resource, class... _Properties>
const _Resource* any_cast(const synchronous_resource_ref<_Properties...>* __ref) noexcept;

# endif // _CCCL_DOXYGEN_INVOKED

template <class _Tp>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#include <cuda/memory_resource>

#include <testing.cuh>

#ifndef __CUDA_ARCH__

struct resource_a
{
void* allocate(cuda::stream_ref, size_t, size_t)
{
return nullptr;
}
void deallocate(cuda::stream_ref, void*, size_t, size_t) noexcept {}
void* allocate_sync(size_t, size_t)
{
return nullptr;
}
void deallocate_sync(void*, size_t, size_t) noexcept {}
bool operator==(const resource_a&) const noexcept
{
return true;
}
bool operator!=(const resource_a&) const noexcept
{
return false;
}
friend void get_property(const resource_a&, cuda::mr::host_accessible) noexcept {}
};

struct resource_b
{
void* allocate(cuda::stream_ref, size_t, size_t)
{
return nullptr;
}
void deallocate(cuda::stream_ref, void*, size_t, size_t) noexcept {}
void* allocate_sync(size_t, size_t)
{
return nullptr;
}
void deallocate_sync(void*, size_t, size_t) noexcept {}
bool operator==(const resource_b&) const noexcept
{
return true;
}
bool operator!=(const resource_b&) const noexcept
{
return false;
}
friend void get_property(const resource_b&, cuda::mr::host_accessible) noexcept {}
};

struct sync_resource_a
{
void* allocate_sync(size_t, size_t)
{
return nullptr;
}
void deallocate_sync(void*, size_t, size_t) noexcept {}
bool operator==(const sync_resource_a&) const noexcept
{
return true;
}
bool operator!=(const sync_resource_a&) const noexcept
{
return false;
}
friend void get_property(const sync_resource_a&, cuda::mr::host_accessible) noexcept {}
};

struct sync_resource_base
{
void* allocate_sync(size_t, size_t)
{
return nullptr;
}
void deallocate_sync(void*, size_t, size_t) noexcept {}
bool operator==(const sync_resource_base&) const noexcept
{
return true;
}
bool operator!=(const sync_resource_base&) const noexcept
{
return false;
}
friend void get_property(const sync_resource_base&, cuda::mr::host_accessible) noexcept {}
};

struct sync_resource_derived : sync_resource_base
{};

struct sync_resource_b
{
void* allocate_sync(size_t, size_t)
{
return nullptr;
}
void deallocate_sync(void*, size_t, size_t) noexcept {}
bool operator==(const sync_resource_b&) const noexcept
{
return true;
}
bool operator!=(const sync_resource_b&) const noexcept
{
return false;
}
friend void get_property(const sync_resource_b&, cuda::mr::host_accessible) noexcept {}
};

TEST_CASE("any_cast on any_resource", "[any_resource]")
{
cuda::mr::any_resource<cuda::mr::host_accessible> mr{resource_a{}};

CHECK(cuda::mr::any_cast<resource_a>(&mr) != nullptr);
CHECK(cuda::mr::any_cast<resource_b>(&mr) == nullptr);

const cuda::mr::any_resource<cuda::mr::host_accessible> const_mr{resource_a{}};
CHECK(cuda::mr::any_cast<resource_a>(&const_mr) != nullptr);
CHECK(cuda::mr::any_cast<resource_b>(&const_mr) == nullptr);

// nullptr input returns nullptr
cuda::mr::any_resource<cuda::mr::host_accessible>* null_mr = nullptr;
CHECK(cuda::mr::any_cast<resource_a>(null_mr) == nullptr);

// empty (default-constructed) any_resource returns nullptr
cuda::mr::any_resource<cuda::mr::host_accessible> empty_mr;
CHECK(cuda::mr::any_cast<resource_a>(&empty_mr) == nullptr);

// empty (moved-from) any_resource returns nullptr
cuda::mr::any_resource<cuda::mr::host_accessible> moved_from{resource_a{}};
auto moved_to = std::move(moved_from);
CHECK(cuda::mr::any_cast<resource_a>(&moved_from) == nullptr);
CHECK(cuda::mr::any_cast<resource_a>(&moved_to) != nullptr);
}

TEST_CASE("any_cast on any_synchronous_resource", "[any_resource]")
{
cuda::mr::any_synchronous_resource<cuda::mr::host_accessible> mr{sync_resource_a{}};

CHECK(cuda::mr::any_cast<sync_resource_a>(&mr) != nullptr);
CHECK(cuda::mr::any_cast<sync_resource_b>(&mr) == nullptr);

const cuda::mr::any_synchronous_resource<cuda::mr::host_accessible> const_mr{sync_resource_a{}};
CHECK(cuda::mr::any_cast<sync_resource_a>(&const_mr) != nullptr);
CHECK(cuda::mr::any_cast<sync_resource_b>(&const_mr) == nullptr);

// nullptr input returns nullptr
cuda::mr::any_synchronous_resource<cuda::mr::host_accessible>* null_mr = nullptr;
CHECK(cuda::mr::any_cast<sync_resource_a>(null_mr) == nullptr);
}

TEST_CASE("any_cast on resource_ref", "[any_resource]")
{
resource_a ra{};
cuda::mr::resource_ref<cuda::mr::host_accessible> ref{ra};

CHECK(cuda::mr::any_cast<resource_a>(&ref) != nullptr);
CHECK(cuda::mr::any_cast<resource_a>(&ref) == &ra);
CHECK(cuda::mr::any_cast<resource_b>(&ref) == nullptr);

const cuda::mr::resource_ref<cuda::mr::host_accessible> const_ref{ra};
CHECK(cuda::mr::any_cast<resource_a>(&const_ref) != nullptr);
CHECK(cuda::mr::any_cast<resource_b>(&const_ref) == nullptr);

// nullptr input returns nullptr
cuda::mr::resource_ref<cuda::mr::host_accessible>* null_ref = nullptr;
CHECK(cuda::mr::any_cast<resource_a>(null_ref) == nullptr);
}

TEST_CASE("any_cast on synchronous_resource_ref", "[any_resource]")
{
sync_resource_a ra{};
cuda::mr::synchronous_resource_ref<cuda::mr::host_accessible> ref{ra};

CHECK(cuda::mr::any_cast<sync_resource_a>(&ref) != nullptr);
CHECK(cuda::mr::any_cast<sync_resource_a>(&ref) == &ra);
CHECK(cuda::mr::any_cast<sync_resource_b>(&ref) == nullptr);

const cuda::mr::synchronous_resource_ref<cuda::mr::host_accessible> const_ref{ra};
CHECK(cuda::mr::any_cast<sync_resource_a>(&const_ref) != nullptr);
CHECK(cuda::mr::any_cast<sync_resource_b>(&const_ref) == nullptr);

// nullptr input returns nullptr
cuda::mr::synchronous_resource_ref<cuda::mr::host_accessible>* null_ref = nullptr;
CHECK(cuda::mr::any_cast<sync_resource_a>(null_ref) == nullptr);
}

TEST_CASE("any_cast requires exact type match (no derived-to-base)", "[any_resource]")
{
// any_cast performs an exact type match, like std::any_cast.
// Casting to a base class when a derived class is stored returns nullptr.
cuda::mr::any_synchronous_resource<cuda::mr::host_accessible> mr{sync_resource_derived{}};
CHECK(cuda::mr::any_cast<sync_resource_derived>(&mr) != nullptr);
CHECK(cuda::mr::any_cast<sync_resource_base>(&mr) == nullptr);

sync_resource_derived derived{};
cuda::mr::synchronous_resource_ref<cuda::mr::host_accessible> ref{derived};
CHECK(cuda::mr::any_cast<sync_resource_derived>(&ref) != nullptr);
CHECK(cuda::mr::any_cast<sync_resource_base>(&ref) == nullptr);
}

#endif // __CUDA_ARCH__