Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 52 additions & 90 deletions auth/access_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ import (
"strings"

authnv1 "k8s.io/api/authentication/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/fluxcd/pkg/cache"
)
Expand All @@ -45,49 +42,57 @@ func GetAccessToken(ctx context.Context, provider Provider, opts ...Option) (Tok
}

// Update access token fetcher for a service account if specified.
var serviceAccount *corev1.ServiceAccount
var providerIdentity string
var audiences []string
if o.ShouldGetServiceAccountToken() {
var saInfo *serviceAccountInfo
if o.ShouldGetServiceAccount() {
// Check the feature gate for object-level workload identity.
if !IsObjectLevelWorkloadIdentityEnabled() {
return nil, ErrObjectLevelWorkloadIdentityNotEnabled
}

// Fetch service account details.
var err error
saRef := client.ObjectKey{
Name: o.ServiceAccountName,
Namespace: o.ServiceAccountNamespace,
}
serviceAccount, audiences, providerIdentity, err =
getServiceAccountAndProviderInfo(ctx, provider, o.Client, saRef, opts...)
saInfo, err = getServiceAccountInfo(ctx, provider, o.Client, opts...)
if err != nil {
return nil, err
}

// Update the function to create an access token using the service account.
newAccessToken = func() (Token, error) {
// Check the feature gate for object-level workload identity.
if !IsObjectLevelWorkloadIdentityEnabled() {
return nil, ErrObjectLevelWorkloadIdentityNotEnabled
if saInfo.useServiceAccount {
newAccessToken = func() (Token, error) {
// Issue Kubernetes OIDC token for the service account.
tokenReq := &authnv1.TokenRequest{
Spec: authnv1.TokenRequestSpec{
Audiences: saInfo.audiences,
},
}
if err := o.Client.SubResource("token").Create(ctx, saInfo.obj, tokenReq); err != nil {
return nil, fmt.Errorf("failed to create kubernetes token for service account '%s/%s': %w",
saInfo.obj.Namespace, saInfo.obj.Name, err)
}
oidcToken := tokenReq.Status.Token

// Exchange the Kubernetes OIDC token for a provider access token.
token, err := provider.NewTokenForServiceAccount(ctx, oidcToken, *saInfo.obj, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create provider access token for service account '%s/%s': %w",
saInfo.obj.Namespace, saInfo.obj.Name, err)
}

return token, nil
}
}
}

// Issue Kubernetes OIDC token for the service account.
tokenReq := &authnv1.TokenRequest{
Spec: authnv1.TokenRequestSpec{
Audiences: audiences,
},
}
if err := o.Client.SubResource("token").Create(ctx, serviceAccount, tokenReq); err != nil {
return nil, fmt.Errorf("failed to create kubernetes token for service account '%s/%s': %w",
serviceAccount.Namespace, serviceAccount.Name, err)
}
oidcToken := tokenReq.Status.Token

// Exchange the Kubernetes OIDC token for a provider access token.
token, err := provider.NewTokenForServiceAccount(ctx, oidcToken, *serviceAccount, opts...)
// Update access token fetcher for impersonation if supported by the provider.
if saInfo != nil && saInfo.providerIdentityForImpersonation != nil {
newNonImpersonatedToken := newAccessToken
newAccessToken = func() (Token, error) {
token, err := newNonImpersonatedToken()
if err != nil {
return nil, fmt.Errorf("failed to create provider access token for service account '%s/%s': %w",
serviceAccount.Namespace, serviceAccount.Name, err)
return nil, err
}

return token, nil
p := provider.(ProviderWithImpersonation)
return p.NewTokenForIdentity(ctx, token, saInfo.providerIdentityForImpersonation, opts...)
}
}

Expand All @@ -97,7 +102,7 @@ func GetAccessToken(ctx context.Context, provider Provider, opts ...Option) (Tok
}

// Build cache key.
cacheKey := buildAccessTokenCacheKey(provider, audiences, providerIdentity, serviceAccount, opts...)
cacheKey := buildAccessTokenCacheKey(provider, saInfo, opts...)

// Build involved object details.
kind := o.InvolvedObject.Kind
Expand All @@ -116,55 +121,7 @@ func GetAccessToken(ctx context.Context, provider Provider, opts ...Option) (Tok
return token, nil
}

func getServiceAccountAndProviderInfo(ctx context.Context, provider Provider, client client.Client,
key client.ObjectKey, opts ...Option) (*corev1.ServiceAccount, []string, string, error) {

var o Options
o.Apply(opts...)

defaultSA := getDefaultServiceAccount()
var setDefaultSA bool

// Apply multi-tenancy lockdown: use default service account when .serviceAccountName
// is not explicitly specified in the object. This results in Object-Level Workload Identity.
if key.Name == "" && defaultSA != "" {
key.Name = defaultSA
setDefaultSA = true
}

// Get service account.
var serviceAccount corev1.ServiceAccount
if err := client.Get(ctx, key, &serviceAccount); err != nil {
if errors.IsNotFound(err) && setDefaultSA {
return nil, nil, "", fmt.Errorf("failed to get service account '%s': %w",
key, ErrDefaultServiceAccountNotFound)
}
return nil, nil, "", fmt.Errorf("failed to get service account '%s': %w",
key, err)
}

// Get provider audience.
audiences := o.Audiences
if len(audiences) == 0 {
var err error
audiences, err = provider.GetAudiences(ctx, serviceAccount)
if err != nil {
return nil, nil, "", fmt.Errorf("failed to get provider audience: %w", err)
}
}

// Get provider identity.
providerIdentity, err := provider.GetIdentity(serviceAccount)
if err != nil {
return nil, nil, "", fmt.Errorf("failed to get provider identity from service account '%s/%s' annotations: %w",
key.Namespace, key.Name, err)
}

return &serviceAccount, audiences, providerIdentity, nil
}

func buildAccessTokenCacheKey(provider Provider, audiences []string, providerIdentity string,
serviceAccount *corev1.ServiceAccount, opts ...Option) string {
func buildAccessTokenCacheKey(provider Provider, saInfo *serviceAccountInfo, opts ...Option) string {

var o Options
o.Apply(opts...)
Expand All @@ -173,11 +130,16 @@ func buildAccessTokenCacheKey(provider Provider, audiences []string, providerIde

parts = append(parts, fmt.Sprintf("provider=%s", provider.GetName()))

if serviceAccount != nil {
parts = append(parts, fmt.Sprintf("providerIdentity=%s", providerIdentity))
parts = append(parts, fmt.Sprintf("serviceAccountName=%s", serviceAccount.Name))
parts = append(parts, fmt.Sprintf("serviceAccountNamespace=%s", serviceAccount.Namespace))
parts = append(parts, fmt.Sprintf("serviceAccountTokenAudiences=%s", strings.Join(audiences, ",")))
if saInfo != nil {
if saInfo.useServiceAccount {
parts = append(parts, fmt.Sprintf("serviceAccountName=%s", saInfo.obj.Name))
parts = append(parts, fmt.Sprintf("serviceAccountNamespace=%s", saInfo.obj.Namespace))
parts = append(parts, fmt.Sprintf("serviceAccountTokenAudiences=%s", strings.Join(saInfo.audiences, ",")))
parts = append(parts, fmt.Sprintf("providerIdentity=%s", saInfo.providerIdentity))
}
if saInfo.providerIdentityForImpersonation != nil {
parts = append(parts, fmt.Sprintf("providerIdentityForImpersonation=%s", saInfo.providerIdentityForImpersonation))
}
}

if len(o.Scopes) > 0 {
Expand Down
Loading
Loading