mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 12:37:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			286 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			286 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) HashiCorp, Inc.
 | 
						|
// SPDX-License-Identifier: MPL-2.0
 | 
						|
 | 
						|
package api
 | 
						|
 | 
						|
import (
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"math/rand"
 | 
						|
	"reflect"
 | 
						|
	"testing"
 | 
						|
	"testing/quick"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/go-test/deep"
 | 
						|
)
 | 
						|
 | 
						|
func TestRenewer_NewRenewer(t *testing.T) {
 | 
						|
	t.Parallel()
 | 
						|
 | 
						|
	client, err := NewClient(DefaultConfig())
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	cases := []struct {
 | 
						|
		name string
 | 
						|
		i    *RenewerInput
 | 
						|
		e    *Renewer
 | 
						|
		err  bool
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name: "nil",
 | 
						|
			i:    nil,
 | 
						|
			e:    nil,
 | 
						|
			err:  true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "missing_secret",
 | 
						|
			i: &RenewerInput{
 | 
						|
				Secret: nil,
 | 
						|
			},
 | 
						|
			e:   nil,
 | 
						|
			err: true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "default_grace",
 | 
						|
			i: &RenewerInput{
 | 
						|
				Secret: &Secret{},
 | 
						|
			},
 | 
						|
			e: &Renewer{
 | 
						|
				secret: &Secret{},
 | 
						|
			},
 | 
						|
			err: false,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tc := range cases {
 | 
						|
		t.Run(tc.name, func(t *testing.T) {
 | 
						|
			v, err := client.NewRenewer(tc.i)
 | 
						|
			if (err != nil) != tc.err {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
 | 
						|
			if v == nil {
 | 
						|
				return
 | 
						|
			}
 | 
						|
 | 
						|
			// Zero-out channels because reflect
 | 
						|
			v.client = nil
 | 
						|
			v.random = nil
 | 
						|
			v.doneCh = nil
 | 
						|
			v.renewCh = nil
 | 
						|
			v.stopCh = nil
 | 
						|
 | 
						|
			if diff := deep.Equal(tc.e, v); diff != nil {
 | 
						|
				t.Error(diff)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestLifetimeWatcher(t *testing.T) {
 | 
						|
	t.Parallel()
 | 
						|
 | 
						|
	client, err := NewClient(DefaultConfig())
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Note that doRenewWithOptions starts its loop with an initial renewal.
 | 
						|
	// This has a big impact on the particulars of the following cases.
 | 
						|
 | 
						|
	renewedSecret := &Secret{}
 | 
						|
	var caseOneErrorCount int
 | 
						|
	var caseManyErrorsCount int
 | 
						|
	cases := []struct {
 | 
						|
		maxTestTime          time.Duration
 | 
						|
		name                 string
 | 
						|
		leaseDurationSeconds int
 | 
						|
		incrementSeconds     int
 | 
						|
		renew                renewFunc
 | 
						|
		expectError          error
 | 
						|
		expectRenewal        bool
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			maxTestTime:          time.Second,
 | 
						|
			name:                 "no_error",
 | 
						|
			leaseDurationSeconds: 60,
 | 
						|
			incrementSeconds:     60,
 | 
						|
			renew: func(_ string, _ int) (*Secret, error) {
 | 
						|
				return renewedSecret, nil
 | 
						|
			},
 | 
						|
			expectError:   nil,
 | 
						|
			expectRenewal: true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			maxTestTime:          time.Second,
 | 
						|
			name:                 "short_increment_duration",
 | 
						|
			leaseDurationSeconds: 60,
 | 
						|
			incrementSeconds:     10,
 | 
						|
			renew: func(_ string, _ int) (*Secret, error) {
 | 
						|
				return renewedSecret, nil
 | 
						|
			},
 | 
						|
			expectError:   nil,
 | 
						|
			expectRenewal: true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			maxTestTime:          5 * time.Second,
 | 
						|
			name:                 "one_error",
 | 
						|
			leaseDurationSeconds: 15,
 | 
						|
			incrementSeconds:     15,
 | 
						|
			renew: func(_ string, _ int) (*Secret, error) {
 | 
						|
				if caseOneErrorCount == 0 {
 | 
						|
					caseOneErrorCount++
 | 
						|
					return nil, fmt.Errorf("renew failure")
 | 
						|
				}
 | 
						|
				return renewedSecret, nil
 | 
						|
			},
 | 
						|
			expectError:   nil,
 | 
						|
			expectRenewal: true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			maxTestTime:          15 * time.Second,
 | 
						|
			name:                 "many_errors",
 | 
						|
			leaseDurationSeconds: 15,
 | 
						|
			incrementSeconds:     15,
 | 
						|
			renew: func(_ string, _ int) (*Secret, error) {
 | 
						|
				if caseManyErrorsCount == 3 {
 | 
						|
					return renewedSecret, nil
 | 
						|
				}
 | 
						|
				caseManyErrorsCount++
 | 
						|
				return nil, fmt.Errorf("renew failure")
 | 
						|
			},
 | 
						|
			expectError:   nil,
 | 
						|
			expectRenewal: true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			maxTestTime:          15 * time.Second,
 | 
						|
			name:                 "only_errors",
 | 
						|
			leaseDurationSeconds: 15,
 | 
						|
			incrementSeconds:     15,
 | 
						|
			renew: func(_ string, _ int) (*Secret, error) {
 | 
						|
				return nil, fmt.Errorf("renew failure")
 | 
						|
			},
 | 
						|
			expectError:   nil,
 | 
						|
			expectRenewal: false,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			maxTestTime:          15 * time.Second,
 | 
						|
			name:                 "negative_lease_duration",
 | 
						|
			leaseDurationSeconds: -15,
 | 
						|
			incrementSeconds:     15,
 | 
						|
			renew: func(_ string, _ int) (*Secret, error) {
 | 
						|
				return renewedSecret, nil
 | 
						|
			},
 | 
						|
			expectError:   nil,
 | 
						|
			expectRenewal: true,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tc := range cases {
 | 
						|
		t.Run(tc.name, func(t *testing.T) {
 | 
						|
			v, err := client.NewLifetimeWatcher(&LifetimeWatcherInput{
 | 
						|
				Secret: &Secret{
 | 
						|
					LeaseDuration: tc.leaseDurationSeconds,
 | 
						|
				},
 | 
						|
				Increment: tc.incrementSeconds,
 | 
						|
			})
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
 | 
						|
			doneCh := make(chan error, 1)
 | 
						|
			go func() {
 | 
						|
				doneCh <- v.doRenewWithOptions(false, false,
 | 
						|
					tc.leaseDurationSeconds, "myleaseID", tc.renew, time.Second)
 | 
						|
			}()
 | 
						|
			defer v.Stop()
 | 
						|
 | 
						|
			receivedRenewal := false
 | 
						|
			receivedDone := false
 | 
						|
		ChannelLoop:
 | 
						|
			for {
 | 
						|
				select {
 | 
						|
				case <-time.After(tc.maxTestTime):
 | 
						|
					t.Fatalf("renewal didn't happen")
 | 
						|
				case r := <-v.RenewCh():
 | 
						|
					if !tc.expectRenewal {
 | 
						|
						t.Fatal("expected no renewals")
 | 
						|
					}
 | 
						|
					if r.Secret != renewedSecret {
 | 
						|
						t.Fatalf("expected secret %v, got %v", renewedSecret, r.Secret)
 | 
						|
					}
 | 
						|
					receivedRenewal = true
 | 
						|
					if !receivedDone {
 | 
						|
						continue ChannelLoop
 | 
						|
					}
 | 
						|
					break ChannelLoop
 | 
						|
				case err := <-doneCh:
 | 
						|
					receivedDone = true
 | 
						|
					if tc.expectError != nil && !errors.Is(err, tc.expectError) {
 | 
						|
						t.Fatalf("expected error %q, got: %v", tc.expectError, err)
 | 
						|
					}
 | 
						|
					if tc.expectError == nil && err != nil {
 | 
						|
						t.Fatalf("expected no error, got: %v", err)
 | 
						|
					}
 | 
						|
					if tc.expectRenewal && !receivedRenewal {
 | 
						|
						// We might have received the stop before the renew call on the channel.
 | 
						|
						continue ChannelLoop
 | 
						|
					}
 | 
						|
					break ChannelLoop
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			if tc.expectRenewal && !receivedRenewal {
 | 
						|
				t.Fatalf("expected at least one renewal, got none.")
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// TestCalcSleepPeriod uses property based testing to evaluate the calculateSleepDuration
 | 
						|
// function of LifeTimeWatchers, but also incidentally tests "calculateGrace".
 | 
						|
// This is on account of "calculateSleepDuration" performing the "calculateGrace"
 | 
						|
// function in particular instances.
 | 
						|
// Both of these functions support the vital functionality of the LifeTimeWatcher
 | 
						|
// and therefore should be tested rigorously.
 | 
						|
func TestCalcSleepPeriod(t *testing.T) {
 | 
						|
	c := quick.Config{
 | 
						|
		MaxCount: 10000,
 | 
						|
		Values: func(values []reflect.Value, r *rand.Rand) {
 | 
						|
			leaseDuration := r.Int63()
 | 
						|
			priorDuration := r.Int63n(leaseDuration)
 | 
						|
			remainingLeaseDuration := r.Int63n(priorDuration)
 | 
						|
			increment := r.Int63n(remainingLeaseDuration)
 | 
						|
 | 
						|
			values[0] = reflect.ValueOf(r)
 | 
						|
			values[1] = reflect.ValueOf(time.Duration(leaseDuration))
 | 
						|
			values[2] = reflect.ValueOf(time.Duration(priorDuration))
 | 
						|
			values[3] = reflect.ValueOf(time.Duration(remainingLeaseDuration))
 | 
						|
			values[4] = reflect.ValueOf(time.Duration(increment))
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	// tests that "calculateSleepDuration" will always return a value less than
 | 
						|
	// the remaining lease duration given a random leaseDuration, priorDuration, remainingLeaseDuration, and increment.
 | 
						|
	// Inputs are generated so that:
 | 
						|
	// leaseDuration > priorDuration > remainingLeaseDuration
 | 
						|
	// and remainingLeaseDuration > increment
 | 
						|
	if err := quick.Check(func(r *rand.Rand, leaseDuration, priorDuration, remainingLeaseDuration, increment time.Duration) bool {
 | 
						|
		lw := LifetimeWatcher{
 | 
						|
			grace:     0,
 | 
						|
			increment: int(increment.Seconds()),
 | 
						|
			random:    r,
 | 
						|
		}
 | 
						|
 | 
						|
		lw.calculateGrace(remainingLeaseDuration, increment)
 | 
						|
 | 
						|
		// ensure that we sleep for less than the remaining lease.
 | 
						|
		return lw.calculateSleepDuration(remainingLeaseDuration, priorDuration) < remainingLeaseDuration
 | 
						|
	}, &c); err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
}
 |