mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-11-04 12:18:16 +00:00 
			
		
		
		
	Add rate limiting when calling STS assume role API
This commit is contained in:
		@@ -1207,7 +1207,7 @@ func init() {
 | 
				
			|||||||
			creds = credentials.NewChainCredentials(
 | 
								creds = credentials.NewChainCredentials(
 | 
				
			||||||
				[]credentials.Provider{
 | 
									[]credentials.Provider{
 | 
				
			||||||
					&credentials.EnvProvider{},
 | 
										&credentials.EnvProvider{},
 | 
				
			||||||
					provider,
 | 
										assumeRoleProvider(provider),
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,62 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2014 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package aws
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/aws/aws-sdk-go/aws/credentials"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						invalidateCredsAfter = 1 * time.Second
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// assumeRoleProviderWithRateLimiting makes sure we call the underlying provider only
 | 
				
			||||||
 | 
					// once after `invalidateCredsAfter` period
 | 
				
			||||||
 | 
					type assumeRoleProviderWithRateLimiting struct {
 | 
				
			||||||
 | 
						provider             credentials.Provider
 | 
				
			||||||
 | 
						invalidateCredsAfter time.Duration
 | 
				
			||||||
 | 
						sync.RWMutex
 | 
				
			||||||
 | 
						lastError        error
 | 
				
			||||||
 | 
						lastValue        credentials.Value
 | 
				
			||||||
 | 
						lastRetrieveTime time.Time
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func assumeRoleProvider(provider credentials.Provider) credentials.Provider {
 | 
				
			||||||
 | 
						return &assumeRoleProviderWithRateLimiting{provider: provider,
 | 
				
			||||||
 | 
							invalidateCredsAfter: invalidateCredsAfter}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (l *assumeRoleProviderWithRateLimiting) Retrieve() (credentials.Value, error) {
 | 
				
			||||||
 | 
						l.Lock()
 | 
				
			||||||
 | 
						defer l.Unlock()
 | 
				
			||||||
 | 
						if time.Since(l.lastRetrieveTime) < l.invalidateCredsAfter {
 | 
				
			||||||
 | 
							if l.lastError != nil {
 | 
				
			||||||
 | 
								return credentials.Value{}, l.lastError
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return l.lastValue, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						l.lastValue, l.lastError = l.provider.Retrieve()
 | 
				
			||||||
 | 
						l.lastRetrieveTime = time.Now()
 | 
				
			||||||
 | 
						return l.lastValue, l.lastError
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (l *assumeRoleProviderWithRateLimiting) IsExpired() bool {
 | 
				
			||||||
 | 
						return l.provider.IsExpired()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,132 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					Copyright 2014 The Kubernetes Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package aws
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/aws/aws-sdk-go/aws/credentials"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Test_assumeRoleProviderWithRateLimiting_Retrieve(t *testing.T) {
 | 
				
			||||||
 | 
						type fields struct {
 | 
				
			||||||
 | 
							provider             credentials.Provider
 | 
				
			||||||
 | 
							invalidateCredsAfter time.Duration
 | 
				
			||||||
 | 
							RWMutex              sync.RWMutex
 | 
				
			||||||
 | 
							lastError            error
 | 
				
			||||||
 | 
							lastValue            credentials.Value
 | 
				
			||||||
 | 
							lastRetrieveTime     time.Time
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name                       string
 | 
				
			||||||
 | 
							fields                     fields
 | 
				
			||||||
 | 
							want                       credentials.Value
 | 
				
			||||||
 | 
							wantProviderCalled         bool
 | 
				
			||||||
 | 
							sleepBeforeCallingProvider time.Duration
 | 
				
			||||||
 | 
							wantErr                    bool
 | 
				
			||||||
 | 
							wantErrString              string
 | 
				
			||||||
 | 
						}{{
 | 
				
			||||||
 | 
							name:               "Call assume role provider and verify access ID returned",
 | 
				
			||||||
 | 
							fields:             fields{provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}},
 | 
				
			||||||
 | 
							want:               credentials.Value{AccessKeyID: "fakeID"},
 | 
				
			||||||
 | 
							wantProviderCalled: true,
 | 
				
			||||||
 | 
						}, {
 | 
				
			||||||
 | 
							name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last value",
 | 
				
			||||||
 | 
							fields: fields{
 | 
				
			||||||
 | 
								provider:             &fakeAssumeRoleProvider{accesskeyID: "fakeID"},
 | 
				
			||||||
 | 
								invalidateCredsAfter: 100 * time.Millisecond,
 | 
				
			||||||
 | 
								lastValue:            credentials.Value{AccessKeyID: "fakeID1"},
 | 
				
			||||||
 | 
								lastRetrieveTime:     time.Now(),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							want:                       credentials.Value{AccessKeyID: "fakeID1"},
 | 
				
			||||||
 | 
							wantProviderCalled:         false,
 | 
				
			||||||
 | 
							sleepBeforeCallingProvider: 10 * time.Millisecond,
 | 
				
			||||||
 | 
						}, {
 | 
				
			||||||
 | 
							name: "Assume role provider returns an error when trying to assume a role",
 | 
				
			||||||
 | 
							fields: fields{
 | 
				
			||||||
 | 
								provider:             &fakeAssumeRoleProvider{err: fmt.Errorf("can't assume fake role")},
 | 
				
			||||||
 | 
								invalidateCredsAfter: 10 * time.Millisecond,
 | 
				
			||||||
 | 
								lastRetrieveTime:     time.Now(),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							wantProviderCalled:         true,
 | 
				
			||||||
 | 
							wantErr:                    true,
 | 
				
			||||||
 | 
							wantErrString:              "can't assume fake role",
 | 
				
			||||||
 | 
							sleepBeforeCallingProvider: 15 * time.Millisecond,
 | 
				
			||||||
 | 
						}, {
 | 
				
			||||||
 | 
							name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last error value",
 | 
				
			||||||
 | 
							fields: fields{
 | 
				
			||||||
 | 
								provider:             &fakeAssumeRoleProvider{},
 | 
				
			||||||
 | 
								invalidateCredsAfter: 100 * time.Millisecond,
 | 
				
			||||||
 | 
								lastRetrieveTime:     time.Now(),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							want:               credentials.Value{},
 | 
				
			||||||
 | 
							wantProviderCalled: false,
 | 
				
			||||||
 | 
							wantErr:            true,
 | 
				
			||||||
 | 
							wantErrString:      "can't assume fake role",
 | 
				
			||||||
 | 
						}, {
 | 
				
			||||||
 | 
							name: "Delayed call to assume role API, should call the underlying provider",
 | 
				
			||||||
 | 
							fields: fields{
 | 
				
			||||||
 | 
								provider:             &fakeAssumeRoleProvider{accesskeyID: "fakeID2"},
 | 
				
			||||||
 | 
								invalidateCredsAfter: 20 * time.Millisecond,
 | 
				
			||||||
 | 
								lastRetrieveTime:     time.Now(),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							want:                       credentials.Value{AccessKeyID: "fakeID2"},
 | 
				
			||||||
 | 
							wantProviderCalled:         true,
 | 
				
			||||||
 | 
							sleepBeforeCallingProvider: 25 * time.Millisecond,
 | 
				
			||||||
 | 
						}}
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								l := &assumeRoleProviderWithRateLimiting{
 | 
				
			||||||
 | 
									provider:             tt.fields.provider,
 | 
				
			||||||
 | 
									invalidateCredsAfter: tt.fields.invalidateCredsAfter,
 | 
				
			||||||
 | 
									lastError:            tt.fields.lastError,
 | 
				
			||||||
 | 
									lastValue:            tt.fields.lastValue,
 | 
				
			||||||
 | 
									lastRetrieveTime:     tt.fields.lastRetrieveTime,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								time.Sleep(tt.sleepBeforeCallingProvider)
 | 
				
			||||||
 | 
								got, err := l.Retrieve()
 | 
				
			||||||
 | 
								if (err != nil) != tt.wantErr && (tt.wantErr && reflect.DeepEqual(err, tt.wantErrString)) {
 | 
				
			||||||
 | 
									t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() error = %v, wantErr %v", err, tt.wantErr)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if !reflect.DeepEqual(got, tt.want) {
 | 
				
			||||||
 | 
									t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() got = %v, want %v", got, tt.want)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if tt.wantProviderCalled != tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled {
 | 
				
			||||||
 | 
									t.Errorf("provider called %v, want %v", tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled, tt.wantProviderCalled)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type fakeAssumeRoleProvider struct {
 | 
				
			||||||
 | 
						accesskeyID    string
 | 
				
			||||||
 | 
						err            error
 | 
				
			||||||
 | 
						providerCalled bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (f *fakeAssumeRoleProvider) Retrieve() (credentials.Value, error) {
 | 
				
			||||||
 | 
						f.providerCalled = true
 | 
				
			||||||
 | 
						return credentials.Value{AccessKeyID: f.accesskeyID}, f.err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (f *fakeAssumeRoleProvider) IsExpired() bool { return true }
 | 
				
			||||||
		Reference in New Issue
	
	Block a user