Files
vault/builtin/logical/transit/path_export_test.go
Steven Clark 92682f33ce Address a panic when exporting RSA public keys in transit (#24054)
* Address a panic when exporting RSA public keys in transit

 - When attempting to export the public key for an RSA key that
   we only have a private key for, the export panics with a nil
   pointer dereference.
 - Add additional tests around Transit key exporting

* Add cl
2023-11-14 09:40:37 -05:00

632 lines
17 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package transit
import (
"context"
cryptoRand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestTransit_Export_Unknown_ExportType verifies that requesting an export with
// an unrecognized export type fails with both a non-nil error and an error
// response whose message mentions "invalid export type".
func TestTransit_Export_Unknown_ExportType(t *testing.T) {
	t.Parallel()
	b, storage := createBackendWithSysView(t)

	keyType := "ed25519"
	req := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
		Data: map[string]interface{}{
			"exportable": true,
			"type":       keyType,
		},
	}
	_, err := b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatalf("failed creating key %s: %v", keyType, err)
	}

	req = &logical.Request{
		Storage:   storage,
		Operation: logical.ReadOperation,
		Path:      "export/bad-export-type/foo",
	}
	rsp, err := b.HandleRequest(context.Background(), req)
	if err == nil {
		t.Fatalf("did not error on bad export type got: %v", rsp)
	}
	if rsp == nil || !rsp.IsError() {
		t.Fatalf("response did not contain an error on bad export type got: %v", rsp)
	}
	// Report the response error that was actually inspected, not the Go error.
	if respErr := rsp.Error(); !strings.Contains(respErr.Error(), "invalid export type") {
		t.Fatalf("failed with unexpected error: %v", respErr)
	}
}
// TestTransit_Export_KeyVersion_ExportsCorrectVersion checks version selection
// ("vN", "N" and "latest") for each export type / key type pairing below.
// Every pairing runs as its own subtest so failures are reported per pairing
// and one failure does not stop the remaining combinations from running.
func TestTransit_Export_KeyVersion_ExportsCorrectVersion(t *testing.T) {
	t.Parallel()
	cases := []struct {
		exportType string
		keyType    string
	}{
		{"encryption-key", "aes128-gcm96"},
		{"encryption-key", "aes256-gcm96"},
		{"encryption-key", "chacha20-poly1305"},
		{"encryption-key", "rsa-2048"},
		{"encryption-key", "rsa-3072"},
		{"encryption-key", "rsa-4096"},
		{"signing-key", "ecdsa-p256"},
		{"signing-key", "ecdsa-p384"},
		{"signing-key", "ecdsa-p521"},
		{"signing-key", "ed25519"},
		{"signing-key", "rsa-2048"},
		{"signing-key", "rsa-3072"},
		{"signing-key", "rsa-4096"},
		{"hmac-key", "aes128-gcm96"},
		{"hmac-key", "aes256-gcm96"},
		{"hmac-key", "chacha20-poly1305"},
		{"hmac-key", "ecdsa-p256"},
		{"hmac-key", "ecdsa-p384"},
		{"hmac-key", "ecdsa-p521"},
		{"hmac-key", "ed25519"},
		{"hmac-key", "hmac"},
		{"public-key", "rsa-2048"},
		{"public-key", "rsa-3072"},
		{"public-key", "rsa-4096"},
		{"public-key", "ecdsa-p256"},
		{"public-key", "ecdsa-p384"},
		{"public-key", "ecdsa-p521"},
		{"public-key", "ed25519"},
	}
	for _, tc := range cases {
		// Subtests run synchronously, so capturing tc in the closure is safe.
		t.Run(tc.exportType+"/"+tc.keyType, func(t *testing.T) {
			verifyExportsCorrectVersion(t, tc.exportType, tc.keyType)
		})
	}
}
// verifyExportsCorrectVersion creates an exportable key "foo" of keyType on a
// fresh backend, rotates it twice, and after each step checks that
// export/<exportType>/foo/<selector> returns exactly one key whose version
// matches the selector ("vN", "N" or "latest") and whose type equals keyType.
func verifyExportsCorrectVersion(t *testing.T, exportType, keyType string) {
	t.Helper()
	b, storage := createBackendWithSysView(t)

	// First create a key, v1
	req := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
	}
	req.Data = map[string]interface{}{
		"exportable": true,
		"type":       keyType,
	}
	if keyType == "hmac" {
		// hmac keys are created with an explicit key_size in these tests.
		req.Data["key_size"] = 32
	}
	_, err := b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}

	// verifyVersion exports "foo" with versionRequest as the version selector
	// and asserts that exactly expectedVersion is returned.
	verifyVersion := func(versionRequest string, expectedVersion int) {
		t.Helper()
		exportReq := &logical.Request{
			Storage:   storage,
			Operation: logical.ReadOperation,
			Path:      fmt.Sprintf("export/%s/foo/%s", exportType, versionRequest),
		}
		rsp, err := b.HandleRequest(context.Background(), exportReq)
		if err != nil {
			t.Fatal(err)
		}
		// Guard against a nil response before dereferencing rsp.Data.
		if rsp == nil {
			t.Fatalf("no response exporting %s version %q", exportType, versionRequest)
		}

		typRaw, ok := rsp.Data["type"]
		if !ok {
			t.Fatal("no type returned from export")
		}
		typ, ok := typRaw.(string)
		if !ok {
			t.Fatalf("could not find key type, resp data is %#v", rsp.Data)
		}
		if typ != keyType {
			t.Fatalf("key type mismatch; %q vs %q", typ, keyType)
		}

		keysRaw, ok := rsp.Data["keys"]
		if !ok {
			t.Fatal("could not find keys value")
		}
		keys, ok := keysRaw.(map[string]string)
		if !ok {
			t.Fatal("could not cast to keys map")
		}
		if len(keys) != 1 {
			t.Fatalf("unexpected number of keys found: %d", len(keys))
		}
		want := strconv.Itoa(expectedVersion)
		for k := range keys {
			if k != want {
				t.Fatalf("expected version %q, received version %q", want, k)
			}
		}
	}

	verifyVersion("v1", 1)
	verifyVersion("1", 1)
	verifyVersion("latest", 1)

	req.Path = "keys/foo/rotate"
	// v2
	_, err = b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}
	verifyVersion("v1", 1)
	verifyVersion("1", 1)
	verifyVersion("v2", 2)
	verifyVersion("2", 2)
	verifyVersion("latest", 2)

	// v3
	_, err = b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}
	verifyVersion("v1", 1)
	verifyVersion("1", 1)
	verifyVersion("v3", 3)
	verifyVersion("3", 3)
	verifyVersion("latest", 3)
}
// TestTransit_Export_ValidVersionsOnly verifies that an encryption-key export
// contains exactly the key versions that remain valid. After
// min_decryption_version is raised, older versions are omitted; after it is
// lowered again, they reappear; a newly rotated version is always included.
func TestTransit_Export_ValidVersionsOnly(t *testing.T) {
	t.Parallel()
	b, storage := createBackendWithSysView(t)

	// First create a key, v1
	req := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
	}
	req.Data = map[string]interface{}{
		"exportable": true,
	}
	_, err := b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}

	req.Path = "keys/foo/rotate"
	// v2
	_, err = b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}
	// v3
	_, err = b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}

	// verifyExport exports "foo" as an encryption key and checks that exactly
	// validVersions are present. It builds its own request rather than
	// overwriting the outer req.
	verifyExport := func(validVersions []int) {
		t.Helper()
		exportReq := &logical.Request{
			Storage:   storage,
			Operation: logical.ReadOperation,
			Path:      "export/encryption-key/foo",
		}
		rsp, err := b.HandleRequest(context.Background(), exportReq)
		if err != nil {
			t.Fatal(err)
		}
		if rsp == nil {
			t.Fatal("no response returned from export")
		}
		keysRaw, ok := rsp.Data["keys"]
		if !ok {
			t.Fatal("no keys returned from export")
		}
		// Stop here on a bad type: the checks below would be meaningless.
		keys, ok := keysRaw.(map[string]string)
		if !ok {
			t.Fatal("could not cast to keys object")
		}
		if len(keys) != len(validVersions) {
			t.Errorf("expected %d key count, received %d", len(validVersions), len(keys))
		}
		for _, version := range validVersions {
			if _, ok := keys[strconv.Itoa(version)]; !ok {
				t.Errorf("expecting to find key version %d, not found", version)
			}
		}
	}

	// setMinDecryptionVersion updates the key's min_decryption_version,
	// failing the test if the config write errors.
	setMinDecryptionVersion := func(version int) {
		t.Helper()
		configReq := &logical.Request{
			Storage:   storage,
			Operation: logical.UpdateOperation,
			Path:      "keys/foo/config",
			Data: map[string]interface{}{
				"min_decryption_version": version,
			},
		}
		if _, err := b.HandleRequest(context.Background(), configReq); err != nil {
			t.Fatal(err)
		}
	}

	verifyExport([]int{1, 2, 3})

	setMinDecryptionVersion(3)
	verifyExport([]int{3})

	setMinDecryptionVersion(2)
	verifyExport([]int{2, 3})

	// v4
	rotateReq := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo/rotate",
	}
	if _, err := b.HandleRequest(context.Background(), rotateReq); err != nil {
		t.Fatal(err)
	}
	verifyExport([]int{2, 3, 4})
}
// TestTransit_Export_KeysNotMarkedExportable_ReturnsError verifies that a key
// created with exportable=false is refused by the encryption-key export
// endpoint: the request succeeds at the transport level but yields an error
// response.
func TestTransit_Export_KeysNotMarkedExportable_ReturnsError(t *testing.T) {
	t.Parallel()
	b, storage := createBackendWithSysView(t)
	ctx := context.Background()

	createReq := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
		Data: map[string]interface{}{
			"exportable": false,
		},
	}
	if _, err := b.HandleRequest(ctx, createReq); err != nil {
		t.Fatal(err)
	}

	exportReq := &logical.Request{
		Storage:   storage,
		Operation: logical.ReadOperation,
		Path:      "export/encryption-key/foo",
	}
	resp, err := b.HandleRequest(ctx, exportReq)
	if err != nil {
		t.Fatal(err)
	}
	if !resp.IsError() {
		t.Fatal("Key not marked as exportable but was exported.")
	}
}
// TestTransit_Export_SigningDoesNotSupportSigning_ReturnsError verifies that
// an exportable aes256-gcm96 key cannot be exported through the signing-key
// export endpoint; the export request must return an error.
func TestTransit_Export_SigningDoesNotSupportSigning_ReturnsError(t *testing.T) {
	t.Parallel()
	b, storage := createBackendWithSysView(t)
	ctx := context.Background()

	if _, err := b.HandleRequest(ctx, &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
		Data: map[string]interface{}{
			"exportable": true,
			"type":       "aes256-gcm96",
		},
	}); err != nil {
		t.Fatal(err)
	}

	if _, err := b.HandleRequest(ctx, &logical.Request{
		Storage:   storage,
		Operation: logical.ReadOperation,
		Path:      "export/signing-key/foo",
	}); err == nil {
		t.Fatal("Key does not support signing but was exported without error.")
	}
}
// TestTransit_Export_EncryptionDoesNotSupportEncryption_ReturnsError verifies
// that the listed key types cannot be exported as encryption keys. Each key
// type runs as its own subtest so a failure for one does not hide the others.
func TestTransit_Export_EncryptionDoesNotSupportEncryption_ReturnsError(t *testing.T) {
	t.Parallel()
	for _, keyType := range []string{"ecdsa-p256", "ecdsa-p384", "ecdsa-p521", "ed25519"} {
		// Subtests run synchronously, so capturing keyType is safe.
		t.Run(keyType, func(t *testing.T) {
			testTransit_Export_EncryptionDoesNotSupportEncryption_ReturnsError(t, keyType)
		})
	}
}
// testTransit_Export_EncryptionDoesNotSupportEncryption_ReturnsError creates
// an exportable key of keyType on a fresh backend and asserts that exporting
// it via export/encryption-key returns an error.
func testTransit_Export_EncryptionDoesNotSupportEncryption_ReturnsError(t *testing.T, keyType string) {
	t.Helper()
	b, storage := createBackendWithSysView(t)

	req := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
	}
	req.Data = map[string]interface{}{
		"exportable": true,
		"type":       keyType,
	}
	_, err := b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatalf("failed creating key %s: %v", keyType, err)
	}

	req = &logical.Request{
		Storage:   storage,
		Operation: logical.ReadOperation,
		Path:      "export/encryption-key/foo",
	}
	_, err = b.HandleRequest(context.Background(), req)
	if err == nil {
		t.Fatalf("Key %s does not support encryption but was exported without error.", keyType)
	}
}
// TestTransit_Export_PublicKeyDoesNotSupportEncryption_ReturnsError verifies
// that symmetric and hmac key types are rejected by the public-key export
// endpoint. Each key type runs as its own subtest so failures are isolated.
func TestTransit_Export_PublicKeyDoesNotSupportEncryption_ReturnsError(t *testing.T) {
	t.Parallel()
	for _, keyType := range []string{"chacha20-poly1305", "aes128-gcm96", "aes256-gcm96", "hmac"} {
		// Subtests run synchronously, so capturing keyType is safe.
		t.Run(keyType, func(t *testing.T) {
			testTransit_Export_PublicKeyNotSupported_ReturnsError(t, keyType)
		})
	}
}
// testTransit_Export_PublicKeyNotSupported_ReturnsError creates a key of
// keyType on a fresh backend and asserts that exporting it via
// export/public-key fails with an "unknown key type <keyType> for export type
// public-key" error.
func testTransit_Export_PublicKeyNotSupported_ReturnsError(t *testing.T, keyType string) {
	t.Helper()
	b, storage := createBackendWithSysView(t)

	req := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
		Data: map[string]interface{}{
			"type": keyType,
		},
	}
	if keyType == "hmac" {
		// hmac keys are created with an explicit key_size in these tests.
		req.Data["key_size"] = 32
	}
	_, err := b.HandleRequest(context.Background(), req)
	if err != nil {
		t.Fatalf("failed creating key %s: %v", keyType, err)
	}

	req = &logical.Request{
		Storage:   storage,
		Operation: logical.ReadOperation,
		Path:      "export/public-key/foo",
	}
	_, err = b.HandleRequest(context.Background(), req)
	if err == nil {
		t.Fatalf("Key %s does not support public key exporting but was exported without error.", keyType)
	}
	if !strings.Contains(err.Error(), fmt.Sprintf("unknown key type %s for export type public-key", keyType)) {
		t.Fatalf("unexpected error value for key type: %s: %v", keyType, err)
	}
}
// TestTransit_Export_KeysDoesNotExist_ReturnsNotFound verifies that exporting
// a key that was never created yields no response and no error; this test
// treats that (nil, nil) result as "not found".
func TestTransit_Export_KeysDoesNotExist_ReturnsNotFound(t *testing.T) {
	t.Parallel()
	b, storage := createBackendWithSysView(t)

	exportReq := &logical.Request{
		Storage:   storage,
		Operation: logical.ReadOperation,
		Path:      "export/encryption-key/foo",
	}
	rsp, err := b.HandleRequest(context.Background(), exportReq)
	if rsp != nil || err != nil {
		t.Fatal("Key does not exist but does not return not found")
	}
}
// TestTransit_Export_EncryptionKey_DoesNotExportHMACKey checks that, for the
// same aes256-gcm96 key, the encryption-key and hmac-key exports return the
// same number of key versions but different key material.
func TestTransit_Export_EncryptionKey_DoesNotExportHMACKey(t *testing.T) {
	t.Parallel()
	b, storage := createBackendWithSysView(t)

	createReq := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/foo",
		Data: map[string]interface{}{
			"exportable": true,
			"type":       "aes256-gcm96",
		},
	}
	if _, err := b.HandleRequest(context.Background(), createReq); err != nil {
		t.Fatal(err)
	}

	// exportAs reads export/<exportType>/foo and fails the test on any error.
	exportAs := func(exportType string) *logical.Response {
		resp, err := b.HandleRequest(context.Background(), &logical.Request{
			Storage:   storage,
			Operation: logical.ReadOperation,
			Path:      "export/" + exportType + "/foo",
		})
		if err != nil {
			t.Fatal(err)
		}
		return resp
	}
	encResp := exportAs("encryption-key")
	hmacResp := exportAs("hmac-key")

	encKeys, encOK := encResp.Data["keys"].(map[string]string)
	if !encOK {
		t.Error("could not cast to keys object")
	}
	hmacKeys, hmacOK := hmacResp.Data["keys"].(map[string]string)
	if !hmacOK {
		t.Error("could not cast to keys object")
	}

	if len(hmacKeys) != len(encKeys) {
		t.Errorf("hmac (%d) and encryption (%d) key count don't match",
			len(hmacKeys), len(encKeys))
	}
	if reflect.DeepEqual(encResp.Data, hmacResp.Data) {
		t.Fatal("Encryption key data matched hmac key data")
	}
}
// TestTransit_Export_CertificateChain starts a test cluster with the transit
// and pki backends mounted, then, for each key type below, imports a key,
// attaches a PKI-signed certificate and checks the certificate-chain export.
func TestTransit_Export_CertificateChain(t *testing.T) {
	t.Parallel()
	generateKeys(t)

	// Create Cluster
	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": Factory,
			"pki":     pki.Factory,
		},
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()
	cores := cluster.Cores
	vault.TestWaitActive(t, cores[0].Core)
	client := cores[0].Client

	// Mount transit backend
	err := client.Sys().Mount("transit", &api.MountInput{
		Type: "transit",
	})
	require.NoError(t, err)

	// Mount PKI backend
	err = client.Sys().Mount("pki", &api.MountInput{
		Type: "pki",
	})
	require.NoError(t, err)

	// Each key type is a separate subtest so failures are reported per type.
	// They run sequentially (no t.Parallel) because every run rewrites the
	// shared pki/roles/example-dot-com role.
	keyTypes := []string{
		"rsa-2048",
		"rsa-3072",
		"rsa-4096",
		"ecdsa-p256",
		"ecdsa-p384",
		"ecdsa-p521",
		"ed25519",
	}
	for _, keyType := range keyTypes {
		t.Run(keyType, func(t *testing.T) {
			testTransit_exportCertificateChain(t, client, keyType)
		})
	}
}
// testTransit_exportCertificateChain runs one round trip for keyType:
//   - it generates a PKI root and signs a CSR for a locally obtained private key;
//   - it imports that private key into transit, wrapped with the transit wrapping key;
//   - it attaches the signed leaf certificate via set-certificate;
//   - it asserts that export/certificate-chain returns that certificate for key version "1".
func testTransit_exportCertificateChain(t *testing.T, apiClient *api.Client, keyType string) {
	t.Helper()
	keyName := keyType
	issuerName := fmt.Sprintf("%s-issuer", keyType)

	// Get key to be imported
	privKey := getKey(t, keyType)
	privKeyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
	require.NoError(t, err, "failed to marshal private key")

	// Create CSR
	var csrTemplate x509.CertificateRequest
	csrTemplate.Subject.CommonName = "example.com"
	csrBytes, err := x509.CreateCertificateRequest(cryptoRand.Reader, &csrTemplate, privKey)
	require.NoError(t, err, "failed to create CSR")
	pemCsr := string(pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE REQUEST",
		Bytes: csrBytes,
	}))

	// Generate PKI root
	_, err = apiClient.Logical().Write("pki/root/generate/internal", map[string]interface{}{
		"issuer_name": issuerName,
		"common_name": "PKI Root X1",
	})
	require.NoError(t, err)

	// Create role to be used in the certificate issuing
	_, err = apiClient.Logical().Write("pki/roles/example-dot-com", map[string]interface{}{
		"issuer_ref":                         issuerName,
		"allowed_domains":                    "example.com",
		"allow_bare_domains":                 true,
		"basic_constraints_valid_for_non_ca": true,
		"key_type":                           "any",
	})
	require.NoError(t, err)

	// Sign the CSR
	resp, err := apiClient.Logical().Write("pki/sign/example-dot-com", map[string]interface{}{
		"issuer_ref": issuerName,
		"csr":        pemCsr,
		"ttl":        "10m",
	})
	require.NoError(t, err)
	require.NotNil(t, resp)
	leafCertPEM, ok := resp.Data["certificate"].(string)
	require.True(t, ok, "certificate missing from pki/sign response")

	// Get wrapping key
	resp, err = apiClient.Logical().Read("transit/wrapping_key")
	require.NoError(t, err)
	require.NotNil(t, resp)
	pubWrappingKeyString, ok := resp.Data["public_key"].(string)
	require.True(t, ok, "public_key missing from transit/wrapping_key response")
	// pem.Decode returns a nil block when no PEM data is found; check it
	// before dereferencing to avoid a panic.
	wrappingKeyPemBlock, _ := pem.Decode([]byte(strings.TrimSpace(pubWrappingKeyString)))
	require.NotNil(t, wrappingKeyPemBlock, "failed to decode wrapping key PEM")
	pubWrappingKey, err := x509.ParsePKIXPublicKey(wrappingKeyPemBlock.Bytes)
	require.NoError(t, err, "failed to parse wrapping key")
	rsaWrappingKey, ok := pubWrappingKey.(*rsa.PublicKey)
	require.True(t, ok, "wrapping key is not an RSA public key")

	blob := wrapTargetPKCS8ForImport(t, rsaWrappingKey, privKeyBytes, "SHA256")

	// Import key
	_, err = apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/import", keyName), map[string]interface{}{
		"ciphertext": blob,
		"type":       keyType,
	})
	require.NoError(t, err)

	// Import cert chain
	_, err = apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/set-certificate", keyName), map[string]interface{}{
		"certificate_chain": leafCertPEM,
	})
	require.NoError(t, err)

	// Export cert chain
	resp, err = apiClient.Logical().Read(fmt.Sprintf("transit/export/certificate-chain/%s", keyName))
	require.NoError(t, err)
	require.NotNil(t, resp)
	exportedKeys, ok := resp.Data["keys"].(map[string]interface{})
	require.True(t, ok, "keys missing from certificate-chain export response")
	require.Equal(t, leafCertPEM, exportedKeys["1"], "expected exported cert chain to match with imported value")
}