mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
Add tests for using raw CSR values
This commit is contained in:
@@ -13,7 +13,9 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -54,7 +56,7 @@ func TestBackend_RSAKey(t *testing.T) {
|
||||
|
||||
intdata := map[string]interface{}{}
|
||||
reqdata := map[string]interface{}{}
|
||||
testCase.Steps = append(testCase.Steps, generateCASteps(t, rsaCACert, rsaCAKey, intdata, reqdata)...)
|
||||
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, intdata, reqdata)...)
|
||||
|
||||
logicaltest.Test(t, testCase)
|
||||
}
|
||||
@@ -84,7 +86,35 @@ func TestBackend_ECKey(t *testing.T) {
|
||||
|
||||
intdata := map[string]interface{}{}
|
||||
reqdata := map[string]interface{}{}
|
||||
testCase.Steps = append(testCase.Steps, generateCASteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
|
||||
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
|
||||
|
||||
logicaltest.Test(t, testCase)
|
||||
}
|
||||
|
||||
func TestBackend_CSRValues(t *testing.T) {
|
||||
defaultLeaseTTLVal := time.Hour * 24
|
||||
maxLeaseTTLVal := time.Hour * 24 * 30
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: defaultLeaseTTLVal,
|
||||
MaxLeaseTTLVal: maxLeaseTTLVal,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create backend: %s", err)
|
||||
}
|
||||
|
||||
testCase := logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{},
|
||||
}
|
||||
|
||||
stepCount += len(testCase.Steps)
|
||||
|
||||
intdata := map[string]interface{}{}
|
||||
reqdata := map[string]interface{}{}
|
||||
testCase.Steps = append(testCase.Steps, generateCSRSteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
|
||||
|
||||
logicaltest.Test(t, testCase)
|
||||
}
|
||||
@@ -252,9 +282,92 @@ func checkCertsAndPrivateKey(keyType string, key crypto.Signer, usage certUsage,
|
||||
return parsedCertBundle, nil
|
||||
}
|
||||
|
||||
func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
|
||||
csrTemplate := x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"MyCountry"},
|
||||
PostalCode: []string{"MyPostalCode"},
|
||||
SerialNumber: "MySerialNumber",
|
||||
CommonName: "my@example.com",
|
||||
},
|
||||
DNSNames: []string{
|
||||
"name1.example.com",
|
||||
"name2.example.com",
|
||||
"name3.example.com",
|
||||
},
|
||||
EmailAddresses: []string{
|
||||
"name1@example.com",
|
||||
"name2@example.com",
|
||||
"name3@example.com",
|
||||
},
|
||||
IPAddresses: []net.IP{
|
||||
net.ParseIP("::ff:1:2:3:4"),
|
||||
net.ParseIP("::ff:5:6:7:8"),
|
||||
},
|
||||
}
|
||||
|
||||
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
csr, _ := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, priv)
|
||||
csrPem := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE REQUEST",
|
||||
Bytes: csr,
|
||||
})
|
||||
|
||||
ret := []logicaltest.TestStep{
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "config/ca/generate/root/exported",
|
||||
Data: map[string]interface{}{
|
||||
"common_name": "Root Cert",
|
||||
"ttl": "180h",
|
||||
},
|
||||
},
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "config/ca/sign",
|
||||
Data: map[string]interface{}{
|
||||
"use_csr_values": true,
|
||||
"csr": string(csrPem),
|
||||
"format": "der",
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
certString := resp.Data["certificate"].(string)
|
||||
if certString == "" {
|
||||
return fmt.Errorf("no certificate returned")
|
||||
}
|
||||
certBytes, _ := base64.StdEncoding.DecodeString(certString)
|
||||
certs, err := x509.ParseCertificates(certBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("returned cert cannot be parsed: %v", err)
|
||||
}
|
||||
if len(certs) != 1 {
|
||||
return fmt.Errorf("unexpected returned length of certificates: %d", len(certs))
|
||||
}
|
||||
cert := certs[0]
|
||||
|
||||
// We need to set these as they are filled in with unparsed values in the final cert
|
||||
csrTemplate.Subject.Names = cert.Subject.Names
|
||||
csrTemplate.Subject.ExtraNames = cert.Subject.ExtraNames
|
||||
switch {
|
||||
case !reflect.DeepEqual(cert.Subject, csrTemplate.Subject):
|
||||
return fmt.Errorf("cert subject\n%#v\ndoes not match csr subject\n%#v\n", cert.Subject, csrTemplate.Subject)
|
||||
case !reflect.DeepEqual(cert.DNSNames, csrTemplate.DNSNames):
|
||||
return fmt.Errorf("cert dns names\n%#v\ndoes not match csr dns names\n%#v\n", cert.DNSNames, csrTemplate.DNSNames)
|
||||
case !reflect.DeepEqual(cert.EmailAddresses, csrTemplate.EmailAddresses):
|
||||
return fmt.Errorf("cert email addresses\n%#v\ndoes not match csr email addresses\n%#v\n", cert.EmailAddresses, csrTemplate.EmailAddresses)
|
||||
case !reflect.DeepEqual(cert.IPAddresses, csrTemplate.IPAddresses):
|
||||
return fmt.Errorf("cert ip addresses\n%#v\ndoes not match csr ip addresses\n%#v\n", cert.IPAddresses, csrTemplate.IPAddresses)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Generates steps to test out CA configuration -- certificates + CRL expiry,
|
||||
// and ensure that the certificates are readable after storing them
|
||||
func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
|
||||
func generateCATestingSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
|
||||
ret := []logicaltest.TestStep{
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
@@ -396,7 +509,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "config/ca/generate/root/exported",
|
||||
Data: map[string]interface{}{
|
||||
"pki_address": "http://example.com/v1/mnt",
|
||||
"common_name": "Root Cert",
|
||||
"ttl": "180h",
|
||||
},
|
||||
@@ -426,7 +538,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
|
||||
delete(reqdata, "pem_bundle")
|
||||
delete(reqdata, "ttl")
|
||||
reqdata["csr"] = intdata["intermediatecsr"].(string)
|
||||
reqdata["pki_address"] = "http://example.com/v1/mnt"
|
||||
reqdata["common_name"] = "Intermediate Cert"
|
||||
reqdata["ttl"] = "90h"
|
||||
return nil
|
||||
@@ -440,7 +551,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
|
||||
Check: func(resp *logical.Response) error {
|
||||
intdata["intermediatecert"] = resp.Data["certificate"].(string)
|
||||
delete(reqdata, "csr")
|
||||
delete(reqdata, "pki_address")
|
||||
delete(reqdata, "common_name")
|
||||
delete(reqdata, "ttl")
|
||||
reqdata["serial_number"] = resp.Data["serial_number"].(string)
|
||||
@@ -482,7 +592,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "config/ca/generate/root/exported",
|
||||
Data: map[string]interface{}{
|
||||
"pki_address": "http://example.com/v1/mnt",
|
||||
"common_name": "Root Cert",
|
||||
"ttl": "180h",
|
||||
"key_type": "ec",
|
||||
@@ -540,7 +649,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
|
||||
delete(reqdata, "pem_bundle")
|
||||
delete(reqdata, "ttl")
|
||||
reqdata["csr"] = intdata["intermediatecsr"].(string)
|
||||
reqdata["pki_address"] = "http://example.com/v1/mnt"
|
||||
reqdata["common_name"] = "Intermediate Cert"
|
||||
reqdata["ttl"] = "90h"
|
||||
return nil
|
||||
@@ -554,7 +662,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
|
||||
Check: func(resp *logical.Response) error {
|
||||
intdata["intermediatecert"] = resp.Data["certificate"].(string)
|
||||
delete(reqdata, "csr")
|
||||
delete(reqdata, "pki_address")
|
||||
delete(reqdata, "common_name")
|
||||
delete(reqdata, "ttl")
|
||||
reqdata["serial_number"] = resp.Data["serial_number"].(string)
|
||||
|
||||
Reference in New Issue
Block a user