diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 95de2d2644..7512ac2796 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -4073,9 +4073,23 @@ func runFullCAChainTest(t *testing.T, keyType string) { } fullChain := resp.Data["ca_chain"].(string) - if strings.Count(fullChain, rootCert) != 1 { - t.Fatalf("expected full chain to contain root certificate; got %v occurrences", strings.Count(fullChain, rootCert)) - } + requireCertInCaChainString(t, fullChain, rootCert, "expected root cert within root cert/ca_chain") + + // Make sure when we issue a leaf certificate we get the full chain back. + resp, err = client.Logical().Write("pki-root/roles/example", map[string]interface{}{ + "allowed_domains": "example.com", + "allow_subdomains": "true", + "max_ttl": "1h", + }) + require.NoError(t, err, "error setting up pki root role: %v", err) + + resp, err = client.Logical().Write("pki-root/issue/example", map[string]interface{}{ + "common_name": "test.example.com", + "ttl": "5m", + }) + require.NoError(t, err, "error issuing certificate from pki root: %v", err) + fullChainArray := resp.Data["ca_chain"].([]interface{}) + requireCertInCaChainArray(t, fullChainArray, rootCert, "expected root cert within root issuance pki-root/issue/example") // Now generate an intermediate at /pki-intermediate, signed by the root. err = client.Sys().Mount("pki-intermediate", &api.MountInput{ @@ -4141,12 +4155,25 @@ func runFullCAChainTest(t *testing.T, keyType string) { require.Equal(t, 0, len(crl.TBSCertList.RevokedCertificates)) fullChain = resp.Data["ca_chain"].(string) - if strings.Count(fullChain, intermediateCert) != 1 { - t.Fatalf("expected full chain to contain intermediate certificate; got %v occurrences", strings.Count(fullChain, intermediateCert)) - } - if strings.Count(fullChain, rootCert) != 1 { - t.Fatalf("expected full chain to contain root certificate; got %v occurrences", strings.Count(fullChain, rootCert)) - } + requireCertInCaChainString(t, fullChain, intermediateCert, "expected full chain to contain intermediate certificate from pki-intermediate/cert/ca_chain") + requireCertInCaChainString(t, fullChain, rootCert, "expected full chain to contain root certificate from pki-intermediate/cert/ca_chain") + + // Make sure when we issue a leaf certificate we get the full chain back. + resp, err = client.Logical().Write("pki-intermediate/roles/example", map[string]interface{}{ + "allowed_domains": "example.com", + "allow_subdomains": "true", + "max_ttl": "1h", + }) + require.NoError(t, err, "error setting up pki intermediate role: %v", err) + + resp, err = client.Logical().Write("pki-intermediate/issue/example", map[string]interface{}{ + "common_name": "test.example.com", + "ttl": "5m", + }) + require.NoError(t, err, "error issuing certificate from pki intermediate: %v", err) + fullChainArray = resp.Data["ca_chain"].([]interface{}) + requireCertInCaChainArray(t, fullChainArray, intermediateCert, "expected full chain to contain intermediate certificate from pki-intermediate/issue/example") + requireCertInCaChainArray(t, fullChainArray, rootCert, "expected full chain to contain root certificate from pki-intermediate/issue/example") // Finally, import this signing cert chain into a new mount to ensure // "external" CAs behave as expected. @@ -4206,6 +4233,23 @@ func runFullCAChainTest(t *testing.T, keyType string) { requireSignedBy(t, issuedCrt, intermediaryCaCert.PublicKey) } +func requireCertInCaChainArray(t *testing.T, chain []interface{}, cert string, msgAndArgs ...interface{}) { + var fullChain string + for _, caCert := range chain { + fullChain = fullChain + "\n" + caCert.(string) + } + + requireCertInCaChainString(t, fullChain, cert, msgAndArgs) +} + +func requireCertInCaChainString(t *testing.T, chain string, cert string, msgAndArgs ...interface{}) { + count := strings.Count(chain, cert) + if count != 1 { + failMsg := fmt.Sprintf("Found %d occurrances of the cert in the provided chain", count) + require.FailNow(t, failMsg, msgAndArgs...) + } +} + type MultiBool int const ( diff --git a/builtin/logical/pki/cert_util.go b/builtin/logical/pki/cert_util.go index 2eaa3ef155..6c2e7b1e2d 100644 --- a/builtin/logical/pki/cert_util.go +++ b/builtin/logical/pki/cert_util.go @@ -1319,6 +1319,7 @@ func generateCreationBundle(b *backend, data *inputBundle, caSign *certutil.CAIn PolicyIdentifiers: data.role.PolicyIdentifiers, BasicConstraintsValidForNonCA: data.role.BasicConstraintsValidForNonCA, NotBeforeDuration: data.role.NotBeforeDuration, + ForceAppendCaChain: caSign != nil, }, SigningBundle: caSign, CSR: csr, diff --git a/builtin/logical/pki/chain_test.go b/builtin/logical/pki/chain_test.go index 95f1715864..04e49c4244 100644 --- a/builtin/logical/pki/chain_test.go +++ b/builtin/logical/pki/chain_test.go @@ -4,8 +4,10 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "strconv" "strings" "testing" + "time" "github.com/hashicorp/vault/api" vaulthttp "github.com/hashicorp/vault/http" @@ -342,6 +344,114 @@ func (c CBUpdateIssuer) Run(t *testing.T, client *api.Client, mount string, know } } +// Issue a leaf +type CBIssueLeaf struct { + Issuer string + Role string +} + +func (c CBIssueLeaf) Run(t *testing.T, client *api.Client, mount string, knownKeys map[string]string, knownCerts map[string]string) { + if len(c.Role) == 0 { + c.Role = "testing" + } + + url := mount + "/roles/" + c.Role + data := make(map[string]interface{}) + data["allow_localhost"] = true + data["ttl"] = "1s" + data["key_type"] = "ec" + + _, err := client.Logical().Write(url, data) + if err != nil { + t.Fatalf("failed to update role (%v): %v / body: %v", c.Role, err, data) + } + + url = mount + "/issuer/" + c.Issuer + "/issue/" + c.Role + data = make(map[string]interface{}) + data["common_name"] = "localhost" + + resp, err := client.Logical().Write(url, data) + if err != nil { + t.Fatalf("failed to issue cert (%v via %v): %v / body: %v", c.Issuer, c.Role, err, data) + } + if resp == nil { + t.Fatalf("failed to issue cert (%v via %v): nil response / body: %v", c.Issuer, c.Role, data) + } +} + +// Stable ordering +func ensureStableOrderingOfChains(t *testing.T, client *api.Client, mount string, knownKeys map[string]string, knownCerts map[string]string) { + // Start by fetching all chains + certChains := make(map[string][]string) + for issuer := range knownCerts { + resp, err := client.Logical().Read(mount + "/issuer/" + issuer) + if err != nil { + t.Fatalf("failed to get chain for issuer (%v): %v", issuer, err) + } + + rawCurrentChain := resp.Data["ca_chain"].([]interface{}) + var currentChain []string + for _, entry := range rawCurrentChain { + currentChain = append(currentChain, strings.TrimSpace(entry.(string))) + } + + certChains[issuer] = currentChain + } + + // Now, generate a bunch of arbitrary roots and validate the chain is + // consistent. + var runs []time.Duration + for i := 0; i < 10; i++ { + name := "stable-order-root-" + strconv.Itoa(i) + step := CBGenerateRoot{ + Key: name, + Name: name, + } + step.Run(t, client, mount, make(map[string]string), make(map[string]string)) + + before := time.Now() + _, err := client.Logical().Delete(mount + "/issuer/" + name) + if err != nil { + t.Fatalf("failed to delete temporary testing issuer %v: %v", name, err) + } + after := time.Now() + elapsed := after.Sub(before) + runs = append(runs, elapsed) + + for issuer := range knownCerts { + resp, err := client.Logical().Read(mount + "/issuer/" + issuer) + if err != nil { + t.Fatalf("failed to get chain for issuer (%v): %v", issuer, err) + } + + rawCurrentChain := resp.Data["ca_chain"].([]interface{}) + for index, entry := range rawCurrentChain { + if strings.TrimSpace(entry.(string)) != certChains[issuer][index] { + t.Fatalf("iteration %d - chain for issuer %v differed at index %d\n%v\nvs\n%v", i, issuer, index, entry, certChains[issuer][index]) + } + } + } + } + + min := runs[0] + max := runs[0] + var avg time.Duration + for _, run := range runs { + if run < min { + min = run + } + + if run > max { + max = run + } + + avg += run + } + avg = avg / time.Duration(len(runs)) + + t.Logf("Chain building run time (deletion) - min: %v / avg: %v / max: %v - entries: %v", min, avg, max, runs) +} + type CBTestStep interface { Run(t *testing.T, client *api.Client, mount string, knownKeys map[string]string, knownCerts map[string]string) } @@ -619,6 +729,14 @@ func Test_CAChainBuilding(t *testing.T) { "full-cycle": "root-old-a,root-old-b,root-old-c,cross-old-new,cross-new-old,root-new-a,root-new-b", }, }, + CBIssueLeaf{Issuer: "root-old-a"}, + CBIssueLeaf{Issuer: "root-old-b"}, + CBIssueLeaf{Issuer: "root-old-c"}, + CBIssueLeaf{Issuer: "cross-old-new"}, + CBIssueLeaf{Issuer: "cross-new-old"}, + CBIssueLeaf{Issuer: "root-new-a"}, + CBIssueLeaf{Issuer: "root-new-b"}, + CBIssueLeaf{Issuer: "inter-a-root-new"}, }, }, { @@ -754,6 +872,15 @@ func Test_CAChainBuilding(t *testing.T) { "d-chained-cross": "root-d,cross-c-d,root-c,cross-b-c,root-b,cross-a-b,root-a", }, }, + CBIssueLeaf{Issuer: "root-a"}, + CBIssueLeaf{Issuer: "cross-a-b"}, + CBIssueLeaf{Issuer: "root-b"}, + CBIssueLeaf{Issuer: "cross-b-c"}, + CBIssueLeaf{Issuer: "root-c"}, + CBIssueLeaf{Issuer: "cross-c-d"}, + CBIssueLeaf{Issuer: "root-d"}, + CBIssueLeaf{Issuer: "cross-d-e"}, + CBIssueLeaf{Issuer: "root-e"}, // Importing the new e->a cross fails because the cycle // it builds is too long. CBGenerateIntermediate{ @@ -777,6 +904,14 @@ func Test_CAChainBuilding(t *testing.T) { CommonName: "root-a", Parent: "root-e", }, + CBIssueLeaf{Issuer: "root-a"}, + CBIssueLeaf{Issuer: "cross-a-b"}, + CBIssueLeaf{Issuer: "root-c"}, + CBIssueLeaf{Issuer: "cross-c-d"}, + CBIssueLeaf{Issuer: "root-d"}, + CBIssueLeaf{Issuer: "cross-d-e"}, + CBIssueLeaf{Issuer: "root-e"}, + CBIssueLeaf{Issuer: "cross-e-a"}, }, }, { @@ -818,6 +953,12 @@ func Test_CAChainBuilding(t *testing.T) { Name: "root-f", CommonName: "root", }, + CBIssueLeaf{Issuer: "root-a"}, + CBIssueLeaf{Issuer: "root-b"}, + CBIssueLeaf{Issuer: "root-c"}, + CBIssueLeaf{Issuer: "root-d"}, + CBIssueLeaf{Issuer: "root-e"}, + CBIssueLeaf{Issuer: "root-f"}, // Seventh reissuance fails. CBGenerateRoot{ Key: "key-root", @@ -834,6 +975,12 @@ func Test_CAChainBuilding(t *testing.T) { Name: "root-g", CommonName: "root", }, + CBIssueLeaf{Issuer: "root-b"}, + CBIssueLeaf{Issuer: "root-c"}, + CBIssueLeaf{Issuer: "root-d"}, + CBIssueLeaf{Issuer: "root-e"}, + CBIssueLeaf{Issuer: "root-f"}, + CBIssueLeaf{Issuer: "root-g"}, }, }, { @@ -965,6 +1112,92 @@ func Test_CAChainBuilding(t *testing.T) { "all-root-old": "root-old-a,root-old-a-reissued,root-old-b,root-old-b-reissued,cross-root-old-b-sig-a,cross-root-old-a-sig-b", }, }, + CBIssueLeaf{Issuer: "root-new-a"}, + CBIssueLeaf{Issuer: "root-new-b"}, + CBIssueLeaf{Issuer: "cross-root-new-b-sig-a"}, + CBIssueLeaf{Issuer: "cross-root-new-a-sig-b"}, + CBIssueLeaf{Issuer: "root-old-a"}, + CBIssueLeaf{Issuer: "root-old-a-reissued"}, + CBIssueLeaf{Issuer: "root-old-b"}, + CBIssueLeaf{Issuer: "root-old-b-reissued"}, + CBIssueLeaf{Issuer: "cross-root-old-b-sig-a"}, + CBIssueLeaf{Issuer: "cross-root-old-a-sig-b"}, + CBIssueLeaf{Issuer: "cross-root-old-a-sig-root-new-a"}, + }, + }, + { + // Test a dual-root of trust chaining example with different + // lengths of chains. + Steps: []CBTestStep{ + CBGenerateRoot{ + Key: "key-root-new", + Name: "root-new", + }, + CBGenerateIntermediate{ + Key: "key-inter-new", + Name: "inter-new", + Parent: "root-new", + }, + CBGenerateRoot{ + Key: "key-root-old", + Name: "root-old", + }, + CBGenerateIntermediate{ + Key: "key-inter-old-a", + Name: "inter-old-a", + Parent: "root-old", + }, + CBGenerateIntermediate{ + Key: "key-inter-old-b", + Name: "inter-old-b", + Parent: "inter-old-a", + }, + // Now generate a cross-signed intermediate to merge these + // two chains. + CBGenerateIntermediate{ + Key: "key-cross-old-new", + Name: "cross-old-new-signed-new", + CommonName: "cross-old-new", + Parent: "inter-new", + }, + CBGenerateIntermediate{ + Key: "key-cross-old-new", + Existing: true, + Name: "cross-old-new-signed-old", + CommonName: "cross-old-new", + Parent: "inter-old-b", + }, + CBGenerateIntermediate{ + Key: "key-leaf-inter", + Name: "leaf-inter", + Parent: "cross-old-new-signed-new", + }, + CBValidateChain{ + Chains: map[string][]string{ + "root-new": {"self"}, + "inter-new": {"self", "root-new"}, + "cross-old-new-signed-new": {"self", "inter-new", "root-new"}, + "root-old": {"self"}, + "inter-old-a": {"self", "root-old"}, + "inter-old-b": {"self", "inter-old-a", "root-old"}, + "cross-old-new-signed-old": {"self", "inter-old-b", "inter-old-a", "root-old"}, + "leaf-inter": {"self", "either-cross", "one-intermediate", "other-inter-or-root", "everything-else", "everything-else", "everything-else", "everything-else"}, + }, + Aliases: map[string]string{ + "either-cross": "cross-old-new-signed-new,cross-old-new-signed-old", + "one-intermediate": "inter-new,inter-old-b", + "other-inter-or-root": "root-new,inter-old-a", + "everything-else": "cross-old-new-signed-new,cross-old-new-signed-old,inter-new,inter-old-b,root-new,inter-old-a,root-old", + }, + }, + CBIssueLeaf{Issuer: "root-new"}, + CBIssueLeaf{Issuer: "inter-new"}, + CBIssueLeaf{Issuer: "root-old"}, + CBIssueLeaf{Issuer: "inter-old-a"}, + CBIssueLeaf{Issuer: "inter-old-b"}, + CBIssueLeaf{Issuer: "cross-old-new-signed-new"}, + CBIssueLeaf{Issuer: "cross-old-new-signed-old"}, + CBIssueLeaf{Issuer: "leaf-inter"}, }, }, } @@ -979,5 +1212,7 @@ func Test_CAChainBuilding(t *testing.T) { testStep.Run(t, client, mount, knownKeys, knownCerts) } + t.Logf("Checking stable ordering of chains...") + ensureStableOrderingOfChains(t, client, mount, knownKeys, knownCerts) } } diff --git a/builtin/logical/pki/chain_util.go b/builtin/logical/pki/chain_util.go index 1d9caaec6f..8774137ffc 100644 --- a/builtin/logical/pki/chain_util.go +++ b/builtin/logical/pki/chain_util.go @@ -81,7 +81,7 @@ func rebuildIssuersChains(ctx context.Context, s logical.Storage, referenceCert // are the same across multiple calls to rebuildIssuersChains with the same // input data. sort.SliceStable(issuers, func(i, j int) bool { - return issuers[i] < issuers[j] + return issuers[i] > issuers[j] }) // We expect each of these maps to be the size of the number of issuers @@ -351,9 +351,33 @@ func rebuildIssuersChains(ctx context.Context, s logical.Storage, referenceCert // ...and add all parents into it. Note that we have to tell if // that parent was already visited or not. if ok && len(parentCerts) > 0 { + // Split children into two categories: roots and intermediates. + // When building a straight-line chain, we want to prefer the + // root (thus, ending the verification) to any cross-signed + // intermediates. If a root is cross-signed, we'll include it's + // cross-signed cert in _its_ chain, thus ignoring our duplicate + // parent here. + // + // Why? When you step from the present node ("issuer") onto one + // of its parents, if you step onto a root, it is a no-op: you + // can still visit all of the neighbors (because any neighbors, + // if they exist, must be cross-signed alternative paths). + // However, if you directly step onto the cross-signed, now you're + // taken in an alternative direction (via its chain), and must + // revisit any roots later. + var roots []issuerID + var intermediates []issuerID + for _, parentCertId := range parentCerts { + if bytes.Equal(issuerIdCertMap[parentCertId].RawSubject, issuerIdCertMap[parentCertId].RawIssuer) { + roots = append(roots, parentCertId) + } else { + intermediates = append(intermediates, parentCertId) + } + } + includedParentCerts := make(map[string]bool, len(parentCerts)+1) includedParentCerts[entry.Certificate] = true - for _, parentCert := range parentCerts { + for _, parentCert := range append(roots, intermediates...) { // See discussion of the algorithm above as to why this is // in the correct order. However, note that we do need to // exclude duplicate certs, hence the map above. @@ -533,7 +557,7 @@ func processAnyCliqueOrCycle( // cross-signing chains; the latter ensures that any cliques can be // strictly bypassed from cycles (but the chain construction later // ensures we pull in the cliques into the cycles). - foundCycles, err := findCyclesNearClique(processedIssuers, issuerIdChildrenMap, allCliqueNodes) + foundCycles, err := findCyclesNearClique(processedIssuers, issuerIdCertMap, issuerIdChildrenMap, allCliqueNodes) if err != nil { // Cycle is too large. return toVisit, err @@ -685,7 +709,7 @@ func processAnyCliqueOrCycle( // Cliques should've been processed by now, if they were necessary // for processable cycles, so ignore them from here to avoid // bloating our search paths. - cycles, err := findAllCyclesWithNode(processedIssuers, issuerIdChildrenMap, issuer, allCliqueNodes) + cycles, err := findAllCyclesWithNode(processedIssuers, issuerIdCertMap, issuerIdChildrenMap, issuer, allCliqueNodes) if err != nil { // To large of cycle. return nil, err @@ -965,6 +989,7 @@ func canonicalizeCycle(cycle []issuerID) []issuerID { func findCyclesNearClique( processedIssuers map[issuerID]bool, + issuerIdCertMap map[issuerID]*x509.Certificate, issuerIdChildrenMap map[issuerID][]issuerID, cliqueNodes []issuerID, ) ([][]issuerID, error) { @@ -993,7 +1018,7 @@ func findCyclesNearClique( } // Find cycles containing this node. - newCycles, err := findAllCyclesWithNode(processedIssuers, issuerIdChildrenMap, child, excludeNodes) + newCycles, err := findAllCyclesWithNode(processedIssuers, issuerIdCertMap, issuerIdChildrenMap, child, excludeNodes) if err != nil { // Found too large of a cycle return nil, err @@ -1009,11 +1034,17 @@ func findCyclesNearClique( excludeNodes = append(excludeNodes, child) } + // Sort cycles from longest->shortest. + sort.SliceStable(knownCycles, func(i, j int) bool { + return len(knownCycles[i]) < len(knownCycles[j]) + }) + return knownCycles, nil } func findAllCyclesWithNode( processedIssuers map[issuerID]bool, + issuerIdCertMap map[issuerID]*x509.Certificate, issuerIdChildrenMap map[issuerID][]issuerID, source issuerID, exclude []issuerID, @@ -1112,6 +1143,7 @@ func findAllCyclesWithNode( if _, ok := pathsTo[child]; !ok { pathsTo[child] = make([][]issuerID, 0) } + for _, path := range pathsTo[current] { if child != source { // We only care about source->source cycles. If this @@ -1160,7 +1192,7 @@ func findAllCyclesWithNode( } } - // Visit this child next. + // Add this child as a candidate to visit next. visitQueue = append(visitQueue, child) // If there's a new parent or we found a new path, then we should @@ -1206,6 +1238,11 @@ func findAllCyclesWithNode( cycles = appendCycleIfNotExisting(cycles, reversed) } + // Sort cycles from longest->shortest. + sort.SliceStable(cycles, func(i, j int) bool { + return len(cycles[i]) > len(cycles[j]) + }) + return cycles, nil } diff --git a/sdk/helper/certutil/helpers.go b/sdk/helper/certutil/helpers.go index 99bed25402..27d056854c 100644 --- a/sdk/helper/certutil/helpers.go +++ b/sdk/helper/certutil/helpers.go @@ -871,16 +871,25 @@ func createCertificate(data *CreationBundle, randReader io.Reader, privateKeyGen } if data.SigningBundle != nil { - if len(data.SigningBundle.Certificate.AuthorityKeyId) > 0 && - !bytes.Equal(data.SigningBundle.Certificate.AuthorityKeyId, data.SigningBundle.Certificate.SubjectKeyId) { + if (len(data.SigningBundle.Certificate.AuthorityKeyId) > 0 && + !bytes.Equal(data.SigningBundle.Certificate.AuthorityKeyId, data.SigningBundle.Certificate.SubjectKeyId)) || + data.Params.ForceAppendCaChain { + var chain []*CertBlock - result.CAChain = []*CertBlock{ - { + signingChain := data.SigningBundle.CAChain + // Some bundles already include the root included in the chain, so don't include it twice. + if len(signingChain) == 0 || !bytes.Equal(signingChain[0].Bytes, data.SigningBundle.CertificateBytes) { + chain = append(chain, &CertBlock{ Certificate: data.SigningBundle.Certificate, Bytes: data.SigningBundle.CertificateBytes, - }, + }) } - result.CAChain = append(result.CAChain, data.SigningBundle.CAChain...) + + if len(signingChain) > 0 { + chain = append(chain, signingChain...) + } + + result.CAChain = chain } } @@ -1158,7 +1167,7 @@ func signCertificate(data *CreationBundle, randReader io.Reader) (*ParsedCertBun return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} } - result.CAChain = data.SigningBundle.GetCAChain() + result.CAChain = data.SigningBundle.GetFullChain() return result, nil } diff --git a/sdk/helper/certutil/types.go b/sdk/helper/certutil/types.go index 7f36c7ab5e..3557f9b0e3 100644 --- a/sdk/helper/certutil/types.go +++ b/sdk/helper/certutil/types.go @@ -704,13 +704,7 @@ func (b *CAInfoBundle) GetCAChain() []*CertBlock { (len(b.Certificate.AuthorityKeyId) == 0 && !bytes.Equal(b.Certificate.RawIssuer, b.Certificate.RawSubject)) { - chain = append(chain, &CertBlock{ - Certificate: b.Certificate, - Bytes: b.CertificateBytes, - }) - if b.CAChain != nil && len(b.CAChain) > 0 { - chain = append(chain, b.CAChain...) - } + chain = b.GetFullChain() } return chain @@ -771,6 +765,7 @@ type CreationParameters struct { PolicyIdentifiers []string BasicConstraintsValidForNonCA bool SignatureBits int + ForceAppendCaChain bool // Only used when signing a CA cert UseCSRValues bool