Add error checking when creating seal.Access object (#24181)

This commit is contained in:
Divya Pola
2023-11-17 13:28:24 -06:00
committed by GitHub
parent 5415d3c8a1
commit 117118e2bd
3 changed files with 57 additions and 16 deletions

View File

@@ -2722,29 +2722,57 @@ func setSeal(c *ServerCommand, config *server.Config, infoKeys []string, info ma
case len(enabledSealWrappers) == 1 && containsShamir(enabledSealWrappers): case len(enabledSealWrappers) == 1 && containsShamir(enabledSealWrappers):
// The barrier seal is Shamir. If there are any disabled seals, then we put them all in the same // The barrier seal is Shamir. If there are any disabled seals, then we put them all in the same
// autoSeal. // autoSeal.
barrierSeal = vault.NewDefaultSeal(vaultseal.NewAccess(sealLogger, sealGenerationInfo, enabledSealWrappers)) a, err := vaultseal.NewAccess(sealLogger, sealGenerationInfo, enabledSealWrappers)
if err != nil {
return nil, err
}
barrierSeal = vault.NewDefaultSeal(a)
if len(disabledSealWrappers) > 0 { if len(disabledSealWrappers) > 0 {
unwrapSeal = vault.NewAutoSeal(vaultseal.NewAccess(sealLogger, sealGenerationInfo, disabledSealWrappers)) a, err = vaultseal.NewAccess(sealLogger, sealGenerationInfo, disabledSealWrappers)
if err != nil {
return nil, err
}
unwrapSeal = vault.NewAutoSeal(a)
} }
case len(disabledSealWrappers) == 1 && containsShamir(disabledSealWrappers): case len(disabledSealWrappers) == 1 && containsShamir(disabledSealWrappers):
// The unwrap seal is Shamir, we are migrating to an autoSeal. // The unwrap seal is Shamir, we are migrating to an autoSeal.
barrierSeal = vault.NewAutoSeal(vaultseal.NewAccess(sealLogger, sealGenerationInfo, enabledSealWrappers)) a, err := vaultseal.NewAccess(sealLogger, sealGenerationInfo, enabledSealWrappers)
unwrapSeal = vault.NewDefaultSeal(vaultseal.NewAccess(sealLogger, sealGenerationInfo, disabledSealWrappers)) if err != nil {
return nil, err
}
barrierSeal = vault.NewAutoSeal(a)
a, err = vaultseal.NewAccess(sealLogger, sealGenerationInfo, disabledSealWrappers)
if err != nil {
return nil, err
}
unwrapSeal = vault.NewDefaultSeal(a)
case server.IsMultisealSupported(): case server.IsMultisealSupported():
// We know we are not using Shamir seal, that we are not migrating away from one, and multi seal is supported, // We know we are not using Shamir seal, that we are not migrating away from one, and multi seal is supported,
// so just put enabled and disabled wrappers on the same seal Access // so just put enabled and disabled wrappers on the same seal Access
allSealWrappers := append(enabledSealWrappers, disabledSealWrappers...) allSealWrappers := append(enabledSealWrappers, disabledSealWrappers...)
barrierSeal = vault.NewAutoSeal(vaultseal.NewAccess(sealLogger, sealGenerationInfo, allSealWrappers)) a, err := vaultseal.NewAccess(sealLogger, sealGenerationInfo, allSealWrappers)
if err != nil {
return nil, err
}
barrierSeal = vault.NewAutoSeal(a)
if configuredSeals < len(enabledSealWrappers) { if configuredSeals < len(enabledSealWrappers) {
c.UI.Warn("WARNING: running with fewer than all configured seals during unseal. Will not be fully highly available until errors are corrected and Vault restarted.") c.UI.Warn("WARNING: running with fewer than all configured seals during unseal. Will not be fully highly available until errors are corrected and Vault restarted.")
} }
case len(enabledSealWrappers) == 1: case len(enabledSealWrappers) == 1:
// We may have multiple seals disabled, but we know Shamir is not one of them. // We may have multiple seals disabled, but we know Shamir is not one of them.
barrierSeal = vault.NewAutoSeal(vaultseal.NewAccess(sealLogger, sealGenerationInfo, enabledSealWrappers)) a, err := vaultseal.NewAccess(sealLogger, sealGenerationInfo, enabledSealWrappers)
if err != nil {
return nil, err
}
barrierSeal = vault.NewAutoSeal(a)
if len(disabledSealWrappers) > 0 { if len(disabledSealWrappers) > 0 {
unwrapSeal = vault.NewAutoSeal(vaultseal.NewAccess(sealLogger, sealGenerationInfo, disabledSealWrappers)) a, err = vaultseal.NewAccess(sealLogger, sealGenerationInfo, disabledSealWrappers)
if err != nil {
return nil, err
}
unwrapSeal = vault.NewAutoSeal(a)
} }
default: default:

View File

@@ -6327,7 +6327,7 @@ func TestGetSealBackendStatus(t *testing.T) {
}) })
} }
shamirSeal := NewDefaultSeal(seal.NewAccess(nil, a, err := seal.NewAccess(nil,
&seal.SealGenerationInfo{ &seal.SealGenerationInfo{
Generation: 1, Generation: 1,
Seals: []*configutil.KMS{{Type: wrapping.WrapperTypeShamir.String()}}, Seals: []*configutil.KMS{{Type: wrapping.WrapperTypeShamir.String()}},
@@ -6340,7 +6340,9 @@ func TestGetSealBackendStatus(t *testing.T) {
Configured: true, Configured: true,
}, },
}, },
)) )
require.NoError(t, err)
shamirSeal := NewDefaultSeal(a)
c := TestCoreWithSeal(t, shamirSeal, false) c := TestCoreWithSeal(t, shamirSeal, false)
keys, _, _ := TestCoreInitClusterWrapperSetup(t, c, nil) keys, _, _ := TestCoreInitClusterWrapperSetup(t, c, nil)

View File

@@ -36,8 +36,11 @@ const (
) )
var ( var (
ErrUnconfiguredWrapper = errors.New("unconfigured wrapper") ErrUnconfiguredWrapper = errors.New("unconfigured wrapper")
ErrNoHealthySeals = errors.New("no healthy seals!") ErrNoHealthySeals = errors.New("no healthy seals!")
ErrNoConfiguredSeals = errors.New("no configured seals")
ErrNoSealGenerationInfo = errors.New("no seal generation info")
ErrNoSeals = errors.New("no seals provided in the configuration")
) )
func (s StoredKeysSupport) String() string { func (s StoredKeysSupport) String() string {
@@ -322,15 +325,17 @@ type access struct {
var _ Access = (*access)(nil) var _ Access = (*access)(nil)
func NewAccess(logger hclog.Logger, sealGenerationInfo *SealGenerationInfo, sealWrappers []*SealWrapper) Access { func NewAccess(logger hclog.Logger, sealGenerationInfo *SealGenerationInfo, sealWrappers []*SealWrapper) (Access, error) {
if logger == nil { if logger == nil {
logger = hclog.NewNullLogger() logger = hclog.NewNullLogger()
} }
if sealGenerationInfo == nil { if sealGenerationInfo == nil {
panic("cannot create a seal.Access without a SealGenerationInfo") logger.Error("cannot create a seal.Access without a SealGenerationInfo")
return nil, ErrNoSealGenerationInfo
} }
if len(sealWrappers) == 0 { if len(sealWrappers) == 0 {
panic("cannot create a seal.Access without any seal wrappers") logger.Error("cannot create a seal.Access without any seal wrappers")
return nil, ErrNoSeals
} }
a := &access{ a := &access{
sealGenerationInfo: sealGenerationInfo, sealGenerationInfo: sealGenerationInfo,
@@ -341,9 +346,15 @@ func NewAccess(logger hclog.Logger, sealGenerationInfo *SealGenerationInfo, seal
a.wrappersByPriority[i] = sw a.wrappersByPriority[i] = sw
} }
configuredSealWrappers := a.GetConfiguredSealWrappersByPriority()
if len(configuredSealWrappers) == 0 {
a.logger.Error("cannot create a seal.Access without any configured seal wrappers")
return nil, ErrNoConfiguredSeals
}
sort.Slice(a.wrappersByPriority, func(i int, j int) bool { return a.wrappersByPriority[i].Priority < a.wrappersByPriority[j].Priority }) sort.Slice(a.wrappersByPriority, func(i int, j int) bool { return a.wrappersByPriority[i].Priority < a.wrappersByPriority[j].Priority })
return a return a, nil
} }
func NewAccessFromSealWrappers(logger hclog.Logger, generation uint64, rewrapped bool, sealWrappers []*SealWrapper) (Access, error) { func NewAccessFromSealWrappers(logger hclog.Logger, generation uint64, rewrapped bool, sealWrappers []*SealWrapper) (Access, error) {
@@ -363,7 +374,7 @@ func NewAccessFromSealWrappers(logger hclog.Logger, generation uint64, rewrapped
Name: sw.Name, Name: sw.Name,
}) })
} }
return NewAccess(logger, sealGenerationInfo, sealWrappers), nil return NewAccess(logger, sealGenerationInfo, sealWrappers)
} }
// NewAccessFromWrapper creates an enabled Access for a single wrapping.Wrapper. // NewAccessFromWrapper creates an enabled Access for a single wrapping.Wrapper.