diff --git a/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/BUILD b/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/BUILD index 5f552f08145c..504c6c70514a 100644 --- a/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/BUILD +++ b/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/BUILD @@ -5,5 +5,8 @@ load("@envoy_api//bazel:api_build_system.bzl", "api_proto_package") licenses(["notice"]) # Apache 2 api_proto_package( - deps = ["@xds//udpa/annotations:pkg"], + deps = [ + "//envoy/config/core/v3:pkg", + "@xds//udpa/annotations:pkg", + ], ) diff --git a/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/downstream_reverse_connection_socket_interface.proto b/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/downstream_reverse_connection_socket_interface.proto index 72994c07973c..cb3ccb4c1d9c 100644 --- a/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/downstream_reverse_connection_socket_interface.proto +++ b/api/envoy/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/v3/downstream_reverse_connection_socket_interface.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.extensions.bootstrap.reverse_tunnel.downstream_socket_interface.v3; +import "envoy/config/core/v3/base.proto"; + import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.extensions.bootstrap.reverse_tunnel.downstream_socket_interface.v3"; @@ -22,6 +24,9 @@ message DownstreamReverseConnectionSocketInterface { // Request path used when issuing the HTTP reverse-connection handshake. Defaults to // "/reverse_connections/request". string request_path = 1; + + // Additional headers to include in the HTTP handshake request. + repeated config.core.v3.HeaderValueOption additional_headers = 2; } // Stat prefix to be used for downstream reverse connection socket interface stats. diff --git a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper.cc b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper.cc index e35ffcc870f4..c1a1603ff71b 100644 --- a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper.cc +++ b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper.cc @@ -139,6 +139,9 @@ std::string RCConnectionWrapper::connect(const std::string& src_tenant_id, headers->addCopy(cluster_hdr, std::string(cluster_id)); headers->addCopy(tenant_hdr, std::string(tenant_id)); headers->addCopy(upstream_cluster_hdr, cluster_name_); + for (const auto& [key, value] : parent_.additionalHeaders()) { + headers->addCopy(key, value); + } headers->setContentLength(0); // Encode via HTTP/1 codec. diff --git a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_connection_io_handle.h b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_connection_io_handle.h index e63ae65b1c6c..bdafe251c6ce 100644 --- a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_connection_io_handle.h +++ b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_connection_io_handle.h @@ -83,6 +83,8 @@ struct ReverseConnectionSocketConfig { std::string src_tenant_id; // Tenant identifier of local envoy instance. std::string request_path{ std::string(ReverseConnectionUtility::DEFAULT_REVERSE_TUNNEL_REQUEST_PATH)}; + std::vector> + additional_headers; // Additional headers for the handshake request. // TODO(basundhara-c): Add support for multiple remote clusters using the same // ReverseConnectionIOHandle. Currently, each ReverseConnectionIOHandle handles // reverse connections for a single upstream cluster since a different ReverseConnectionAddress @@ -305,6 +307,13 @@ class ReverseConnectionIOHandle : public Network::IoSocketHandleImpl, */ const std::string& requestPath() const { return config_.request_path; } + /** + * @return reference to the additional headers for the handshake request. + */ + const std::vector>& additionalHeaders() const { + return config_.additional_headers; + } + private: /** * Get time source for consistent time operations. diff --git a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator.cc b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator.cc index 7917b57516eb..1f73dd4844fa 100644 --- a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator.cc +++ b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator.cc @@ -127,6 +127,7 @@ ReverseTunnelInitiator::socket(Envoy::Network::Socket::Type socket_type, socket_config.remote_clusters.push_back(cluster_config); if (extension_ != nullptr) { socket_config.request_path = extension_->handshakeRequestPath(); + socket_config.additional_headers = extension_->handshakeAdditionalHeaders(); } // Pass config directly to helper method. diff --git a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension.h b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension.h index ab322edd10bc..4d52ddc8732e 100644 --- a/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension.h +++ b/source/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension.h @@ -44,6 +44,12 @@ class ReverseTunnelInitiatorExtension : public Server::BootstrapExtension, handshake_request_path_ = std::string(ReverseConnectionUtility::DEFAULT_REVERSE_TUNNEL_REQUEST_PATH); } + if (config.has_http_handshake()) { + for (const auto& header : config.http_handshake().additional_headers()) { + additional_headers_.emplace_back(Http::LowerCaseString(header.header().key()), + header.header().value()); + } + } ENVOY_LOG(debug, "ReverseTunnelInitiatorExtension: creating downstream reverse connection " "socket interface with stat_prefix: {}", @@ -105,6 +111,14 @@ class ReverseTunnelInitiatorExtension : public Server::BootstrapExtension, */ const std::string& handshakeRequestPath() const { return handshake_request_path_; } + /** + * @return reference to the additional headers to include in the handshake request. + */ + const std::vector>& + handshakeAdditionalHeaders() const { + return additional_headers_; + } + /** * Increment handshake stats for reverse tunnel connections (per-worker only). * Only tracks stats if enable_detailed_stats flag is true. @@ -134,6 +148,7 @@ class ReverseTunnelInitiatorExtension : public Server::BootstrapExtension, std::string stat_prefix_; // Reverse connection stats prefix bool enable_detailed_stats_{false}; std::string handshake_request_path_; + std::vector> additional_headers_; /** * Update per-worker connection stats for debugging purposes. diff --git a/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper_test.cc b/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper_test.cc index e808003ff887..4de8f325e7d2 100644 --- a/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper_test.cc +++ b/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/rc_connection_wrapper_test.cc @@ -305,6 +305,51 @@ TEST_F(RCConnectionWrapperTest, ConnectHttpHandshakeWithCustomRequestPath) { EXPECT_NE(encoded_request.find("GET /custom/handshake HTTP/1.1"), std::string::npos); } +// Test RCConnectionWrapper::connect() includes additional headers in the handshake request. +TEST_F(RCConnectionWrapperTest, ConnectHttpHandshakeWithAdditionalHeaders) { + auto mock_connection = std::make_unique>(); + + EXPECT_CALL(*mock_connection, addConnectionCallbacks(_)); + EXPECT_CALL(*mock_connection, addReadFilter(_)); + EXPECT_CALL(*mock_connection, connect()); + EXPECT_CALL(*mock_connection, id()).WillRepeatedly(Return(12345)); + EXPECT_CALL(*mock_connection, state()).WillRepeatedly(Return(Network::Connection::State::Open)); + + auto mock_address = std::make_shared("192.168.1.1", 8080); + auto mock_local_address = std::make_shared("127.0.0.1", 12345); + + EXPECT_CALL(*mock_connection, connectionInfoProvider()) + .WillRepeatedly(Invoke([mock_address, + mock_local_address]() -> const Network::ConnectionInfoProvider& { + static auto mock_provider = + std::make_unique(mock_local_address, mock_address); + return *mock_provider; + })); + + Buffer::OwnedImpl captured_buffer; + EXPECT_CALL(*mock_connection, write(_, _)) + .WillOnce(Invoke([&captured_buffer](Buffer::Instance& buffer, bool) { + captured_buffer.add(buffer); + buffer.drain(buffer.length()); + })); + + auto mock_host = std::make_shared>(); + + ReverseConnectionSocketConfig custom_config = createDefaultTestConfig(); + custom_config.additional_headers.emplace_back(Http::LowerCaseString("x-custom-auth"), "token123"); + custom_config.additional_headers.emplace_back(Http::LowerCaseString("x-request-id"), "abc-def"); + auto local_io_handle = createTestIOHandle(custom_config); + + RCConnectionWrapper wrapper(*local_io_handle, std::move(mock_connection), mock_host, + "test-cluster"); + + wrapper.connect("test-tenant", "test-cluster", "test-node"); + + const std::string encoded_request = captured_buffer.toString(); + EXPECT_NE(encoded_request.find("x-custom-auth: token123"), std::string::npos); + EXPECT_NE(encoded_request.find("x-request-id: abc-def"), std::string::npos); +} + // Test RCConnectionWrapper::connect() method with connection write failure. TEST_F(RCConnectionWrapperTest, ConnectHttpHandshakeWriteFailure) { // Create a mock connection that fails to write. diff --git a/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension_test.cc b/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension_test.cc index ad7fc6b72d07..7525fd1698ee 100644 --- a/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension_test.cc +++ b/test/extensions/bootstrap/reverse_tunnel/downstream_socket_interface/reverse_tunnel_initiator_extension_test.cc @@ -143,6 +143,28 @@ TEST_F(ReverseTunnelInitiatorExtensionTest, HandshakeRequestPathOverride) { EXPECT_EQ(custom_extension->handshakeRequestPath(), "/custom/handshake"); } +TEST_F(ReverseTunnelInitiatorExtensionTest, AdditionalHeadersDefaults) { + EXPECT_TRUE(extension_->handshakeAdditionalHeaders().empty()); +} + +TEST_F(ReverseTunnelInitiatorExtensionTest, AdditionalHeadersOverride) { + auto custom_config = config_; + auto* hdr1 = custom_config.mutable_http_handshake()->add_additional_headers(); + hdr1->mutable_header()->set_key("x-custom-auth"); + hdr1->mutable_header()->set_value("token123"); + auto* hdr2 = custom_config.mutable_http_handshake()->add_additional_headers(); + hdr2->mutable_header()->set_key("x-request-id"); + hdr2->mutable_header()->set_value("abc-def"); + auto custom_extension = + std::make_unique(context_, custom_config); + const auto& headers = custom_extension->handshakeAdditionalHeaders(); + ASSERT_EQ(headers.size(), 2); + EXPECT_EQ(headers[0].first.get(), "x-custom-auth"); + EXPECT_EQ(headers[0].second, "token123"); + EXPECT_EQ(headers[1].first.get(), "x-request-id"); + EXPECT_EQ(headers[1].second, "abc-def"); +} + TEST_F(ReverseTunnelInitiatorExtensionTest, OnServerInitialized) { // This should be a no-op. extension_->onServerInitialized(server_);