Add tests for using raw CSR values

This commit is contained in:
Jeff Mitchell
2015-10-14 11:46:01 -04:00
parent a763391615
commit 4de2060a96

View File

@@ -13,7 +13,9 @@ import (
"fmt"
"math"
mathrand "math/rand"
"net"
"os"
"reflect"
"strings"
"testing"
"time"
@@ -54,7 +56,7 @@ func TestBackend_RSAKey(t *testing.T) {
intdata := map[string]interface{}{}
reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCASteps(t, rsaCACert, rsaCAKey, intdata, reqdata)...)
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, intdata, reqdata)...)
logicaltest.Test(t, testCase)
}
@@ -84,7 +86,35 @@ func TestBackend_ECKey(t *testing.T) {
intdata := map[string]interface{}{}
reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCASteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
logicaltest.Test(t, testCase)
}
func TestBackend_CSRValues(t *testing.T) {
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 30
b, err := Factory(&logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
})
if err != nil {
t.Fatalf("Unable to create backend: %s", err)
}
testCase := logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{},
}
stepCount += len(testCase.Steps)
intdata := map[string]interface{}{}
reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCSRSteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
logicaltest.Test(t, testCase)
}
@@ -252,9 +282,92 @@ func checkCertsAndPrivateKey(keyType string, key crypto.Signer, usage certUsage,
return parsedCertBundle, nil
}
func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
csrTemplate := x509.CertificateRequest{
Subject: pkix.Name{
Country: []string{"MyCountry"},
PostalCode: []string{"MyPostalCode"},
SerialNumber: "MySerialNumber",
CommonName: "my@example.com",
},
DNSNames: []string{
"name1.example.com",
"name2.example.com",
"name3.example.com",
},
EmailAddresses: []string{
"name1@example.com",
"name2@example.com",
"name3@example.com",
},
IPAddresses: []net.IP{
net.ParseIP("::ff:1:2:3:4"),
net.ParseIP("::ff:5:6:7:8"),
},
}
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
csr, _ := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, priv)
csrPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csr,
})
ret := []logicaltest.TestStep{
logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "config/ca/generate/root/exported",
Data: map[string]interface{}{
"common_name": "Root Cert",
"ttl": "180h",
},
},
logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "config/ca/sign",
Data: map[string]interface{}{
"use_csr_values": true,
"csr": string(csrPem),
"format": "der",
},
Check: func(resp *logical.Response) error {
certString := resp.Data["certificate"].(string)
if certString == "" {
return fmt.Errorf("no certificate returned")
}
certBytes, _ := base64.StdEncoding.DecodeString(certString)
certs, err := x509.ParseCertificates(certBytes)
if err != nil {
return fmt.Errorf("returned cert cannot be parsed: %v", err)
}
if len(certs) != 1 {
return fmt.Errorf("unexpected returned length of certificates: %d", len(certs))
}
cert := certs[0]
// We need to set these as they are filled in with unparsed values in the final cert
csrTemplate.Subject.Names = cert.Subject.Names
csrTemplate.Subject.ExtraNames = cert.Subject.ExtraNames
switch {
case !reflect.DeepEqual(cert.Subject, csrTemplate.Subject):
return fmt.Errorf("cert subject\n%#v\ndoes not match csr subject\n%#v\n", cert.Subject, csrTemplate.Subject)
case !reflect.DeepEqual(cert.DNSNames, csrTemplate.DNSNames):
return fmt.Errorf("cert dns names\n%#v\ndoes not match csr dns names\n%#v\n", cert.DNSNames, csrTemplate.DNSNames)
case !reflect.DeepEqual(cert.EmailAddresses, csrTemplate.EmailAddresses):
return fmt.Errorf("cert email addresses\n%#v\ndoes not match csr email addresses\n%#v\n", cert.EmailAddresses, csrTemplate.EmailAddresses)
case !reflect.DeepEqual(cert.IPAddresses, csrTemplate.IPAddresses):
return fmt.Errorf("cert ip addresses\n%#v\ndoes not match csr ip addresses\n%#v\n", cert.IPAddresses, csrTemplate.IPAddresses)
}
return nil
},
},
}
return ret
}
// Generates steps to test out CA configuration -- certificates + CRL expiry,
// and ensure that the certificates are readable after storing them
func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
func generateCATestingSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
ret := []logicaltest.TestStep{
logicaltest.TestStep{
Operation: logical.WriteOperation,
@@ -396,7 +509,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
Operation: logical.WriteOperation,
Path: "config/ca/generate/root/exported",
Data: map[string]interface{}{
"pki_address": "http://example.com/v1/mnt",
"common_name": "Root Cert",
"ttl": "180h",
},
@@ -426,7 +538,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
delete(reqdata, "pem_bundle")
delete(reqdata, "ttl")
reqdata["csr"] = intdata["intermediatecsr"].(string)
reqdata["pki_address"] = "http://example.com/v1/mnt"
reqdata["common_name"] = "Intermediate Cert"
reqdata["ttl"] = "90h"
return nil
@@ -440,7 +551,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
Check: func(resp *logical.Response) error {
intdata["intermediatecert"] = resp.Data["certificate"].(string)
delete(reqdata, "csr")
delete(reqdata, "pki_address")
delete(reqdata, "common_name")
delete(reqdata, "ttl")
reqdata["serial_number"] = resp.Data["serial_number"].(string)
@@ -482,7 +592,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
Operation: logical.WriteOperation,
Path: "config/ca/generate/root/exported",
Data: map[string]interface{}{
"pki_address": "http://example.com/v1/mnt",
"common_name": "Root Cert",
"ttl": "180h",
"key_type": "ec",
@@ -540,7 +649,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
delete(reqdata, "pem_bundle")
delete(reqdata, "ttl")
reqdata["csr"] = intdata["intermediatecsr"].(string)
reqdata["pki_address"] = "http://example.com/v1/mnt"
reqdata["common_name"] = "Intermediate Cert"
reqdata["ttl"] = "90h"
return nil
@@ -554,7 +662,6 @@ func generateCASteps(t *testing.T, caCert, caKey string, intdata, reqdata map[st
Check: func(resp *logical.Response) error {
intdata["intermediatecert"] = resp.Data["certificate"].(string)
delete(reqdata, "csr")
delete(reqdata, "pki_address")
delete(reqdata, "common_name")
delete(reqdata, "ttl")
reqdata["serial_number"] = resp.Data["serial_number"].(string)