Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@ load("@envoy_api//bazel:api_build_system.bzl", "api_proto_package")
licenses(["notice"]) # Apache 2

api_proto_package(
deps = ["@xds//udpa/annotations:pkg"],
deps = [
"//envoy/config/core/v3:pkg",
"@xds//udpa/annotations:pkg",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ syntax = "proto3";

package envoy.extensions.bootstrap.reverse_tunnel.downstream_socket_interface.v3;

import "envoy/config/core/v3/base.proto";

import "udpa/annotations/status.proto";

option java_package = "io.envoyproxy.envoy.extensions.bootstrap.reverse_tunnel.downstream_socket_interface.v3";
Expand All @@ -22,6 +24,9 @@ message DownstreamReverseConnectionSocketInterface {
// Request path used when issuing the HTTP reverse-connection handshake. Defaults to
// "/reverse_connections/request".
string request_path = 1;

// Additional headers to include in the HTTP handshake request.
repeated config.core.v3.HeaderValueOption additional_headers = 2;
}

// Stat prefix to be used for downstream reverse connection socket interface stats.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ std::string RCConnectionWrapper::connect(const std::string& src_tenant_id,
headers->addCopy(cluster_hdr, std::string(cluster_id));
headers->addCopy(tenant_hdr, std::string(tenant_id));
headers->addCopy(upstream_cluster_hdr, cluster_name_);
for (const auto& [key, value] : parent_.additionalHeaders()) {
headers->addCopy(key, value);
}
headers->setContentLength(0);

// Encode via HTTP/1 codec.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ struct ReverseConnectionSocketConfig {
std::string src_tenant_id; // Tenant identifier of local envoy instance.
std::string request_path{
std::string(ReverseConnectionUtility::DEFAULT_REVERSE_TUNNEL_REQUEST_PATH)};
std::vector<std::pair<Http::LowerCaseString, std::string>>
additional_headers; // Additional headers for the handshake request.
// TODO(basundhara-c): Add support for multiple remote clusters using the same
// ReverseConnectionIOHandle. Currently, each ReverseConnectionIOHandle handles
// reverse connections for a single upstream cluster since a different ReverseConnectionAddress
Expand Down Expand Up @@ -305,6 +307,13 @@ class ReverseConnectionIOHandle : public Network::IoSocketHandleImpl,
*/
const std::string& requestPath() const { return config_.request_path; }

/**
* @return reference to the additional headers for the handshake request.
*/
const std::vector<std::pair<Http::LowerCaseString, std::string>>& additionalHeaders() const {
return config_.additional_headers;
}

private:
/**
* Get time source for consistent time operations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ ReverseTunnelInitiator::socket(Envoy::Network::Socket::Type socket_type,
socket_config.remote_clusters.push_back(cluster_config);
if (extension_ != nullptr) {
socket_config.request_path = extension_->handshakeRequestPath();
socket_config.additional_headers = extension_->handshakeAdditionalHeaders();
}

// Pass config directly to helper method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class ReverseTunnelInitiatorExtension : public Server::BootstrapExtension,
handshake_request_path_ =
std::string(ReverseConnectionUtility::DEFAULT_REVERSE_TUNNEL_REQUEST_PATH);
}
if (config.has_http_handshake()) {
for (const auto& header : config.http_handshake().additional_headers()) {
additional_headers_.emplace_back(Http::LowerCaseString(header.header().key()),
header.header().value());
}
}
ENVOY_LOG(debug,
"ReverseTunnelInitiatorExtension: creating downstream reverse connection "
"socket interface with stat_prefix: {}",
Expand Down Expand Up @@ -105,6 +111,14 @@ class ReverseTunnelInitiatorExtension : public Server::BootstrapExtension,
*/
const std::string& handshakeRequestPath() const { return handshake_request_path_; }

/**
* @return reference to the additional headers to include in the handshake request.
*/
const std::vector<std::pair<Http::LowerCaseString, std::string>>&
handshakeAdditionalHeaders() const {
return additional_headers_;
}

/**
* Increment handshake stats for reverse tunnel connections (per-worker only).
* Only tracks stats if enable_detailed_stats flag is true.
Expand Down Expand Up @@ -134,6 +148,7 @@ class ReverseTunnelInitiatorExtension : public Server::BootstrapExtension,
std::string stat_prefix_; // Reverse connection stats prefix
bool enable_detailed_stats_{false};
std::string handshake_request_path_;
std::vector<std::pair<Http::LowerCaseString, std::string>> additional_headers_;

/**
* Update per-worker connection stats for debugging purposes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,51 @@ TEST_F(RCConnectionWrapperTest, ConnectHttpHandshakeWithCustomRequestPath) {
EXPECT_NE(encoded_request.find("GET /custom/handshake HTTP/1.1"), std::string::npos);
}

// Test RCConnectionWrapper::connect() includes additional headers in the handshake request.
TEST_F(RCConnectionWrapperTest, ConnectHttpHandshakeWithAdditionalHeaders) {
auto mock_connection = std::make_unique<NiceMock<Network::MockClientConnection>>();

EXPECT_CALL(*mock_connection, addConnectionCallbacks(_));
EXPECT_CALL(*mock_connection, addReadFilter(_));
EXPECT_CALL(*mock_connection, connect());
EXPECT_CALL(*mock_connection, id()).WillRepeatedly(Return(12345));
EXPECT_CALL(*mock_connection, state()).WillRepeatedly(Return(Network::Connection::State::Open));

auto mock_address = std::make_shared<Network::Address::Ipv4Instance>("192.168.1.1", 8080);
auto mock_local_address = std::make_shared<Network::Address::Ipv4Instance>("127.0.0.1", 12345);

EXPECT_CALL(*mock_connection, connectionInfoProvider())
.WillRepeatedly(Invoke([mock_address,
mock_local_address]() -> const Network::ConnectionInfoProvider& {
static auto mock_provider =
std::make_unique<Network::ConnectionInfoSetterImpl>(mock_local_address, mock_address);
return *mock_provider;
}));

Buffer::OwnedImpl captured_buffer;
EXPECT_CALL(*mock_connection, write(_, _))
.WillOnce(Invoke([&captured_buffer](Buffer::Instance& buffer, bool) {
captured_buffer.add(buffer);
buffer.drain(buffer.length());
}));

auto mock_host = std::make_shared<NiceMock<Upstream::MockHostDescription>>();

ReverseConnectionSocketConfig custom_config = createDefaultTestConfig();
custom_config.additional_headers.emplace_back(Http::LowerCaseString("x-custom-auth"), "token123");
custom_config.additional_headers.emplace_back(Http::LowerCaseString("x-request-id"), "abc-def");
auto local_io_handle = createTestIOHandle(custom_config);

RCConnectionWrapper wrapper(*local_io_handle, std::move(mock_connection), mock_host,
"test-cluster");

wrapper.connect("test-tenant", "test-cluster", "test-node");

const std::string encoded_request = captured_buffer.toString();
EXPECT_NE(encoded_request.find("x-custom-auth: token123"), std::string::npos);
EXPECT_NE(encoded_request.find("x-request-id: abc-def"), std::string::npos);
}

// Test RCConnectionWrapper::connect() method with connection write failure.
TEST_F(RCConnectionWrapperTest, ConnectHttpHandshakeWriteFailure) {
// Create a mock connection that fails to write.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,28 @@ TEST_F(ReverseTunnelInitiatorExtensionTest, HandshakeRequestPathOverride) {
EXPECT_EQ(custom_extension->handshakeRequestPath(), "/custom/handshake");
}

TEST_F(ReverseTunnelInitiatorExtensionTest, AdditionalHeadersDefaults) {
EXPECT_TRUE(extension_->handshakeAdditionalHeaders().empty());
}

TEST_F(ReverseTunnelInitiatorExtensionTest, AdditionalHeadersOverride) {
auto custom_config = config_;
auto* hdr1 = custom_config.mutable_http_handshake()->add_additional_headers();
hdr1->mutable_header()->set_key("x-custom-auth");
hdr1->mutable_header()->set_value("token123");
auto* hdr2 = custom_config.mutable_http_handshake()->add_additional_headers();
hdr2->mutable_header()->set_key("x-request-id");
hdr2->mutable_header()->set_value("abc-def");
auto custom_extension =
std::make_unique<ReverseTunnelInitiatorExtension>(context_, custom_config);
const auto& headers = custom_extension->handshakeAdditionalHeaders();
ASSERT_EQ(headers.size(), 2);
EXPECT_EQ(headers[0].first.get(), "x-custom-auth");
EXPECT_EQ(headers[0].second, "token123");
EXPECT_EQ(headers[1].first.get(), "x-request-id");
EXPECT_EQ(headers[1].second, "abc-def");
}

TEST_F(ReverseTunnelInitiatorExtensionTest, OnServerInitialized) {
// This should be a no-op.
extension_->onServerInitialized(server_);
Expand Down
Loading