command: support custom CAs

This commit is contained in:
Mitchell Hashimoto
2015-04-28 09:36:03 -07:00
parent ae1c71085c
commit bacbf6c082

View File

@@ -3,12 +3,16 @@ package command
import ( import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"time" "time"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
@@ -66,14 +70,21 @@ func (m *Meta) Client() (*api.Client, error) {
// If we need custom TLS configuration, then set it // If we need custom TLS configuration, then set it
if m.flagCACert != "" || m.flagCAPath != "" || m.flagInsecure { if m.flagCACert != "" || m.flagCAPath != "" || m.flagInsecure {
var certPool *x509.CertPool
var err error
if m.flagCACert != "" {
certPool, err = m.loadCACert(m.flagCACert)
} else if m.flagCAPath != "" {
certPool, err = m.loadCAPath(m.flagCAPath)
}
if err != nil {
return nil, fmt.Errorf("Error setting up CA path: %s", err)
}
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: m.flagInsecure, InsecureSkipVerify: m.flagInsecure,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} RootCAs: certPool,
if m.flagCACert != "" || m.flagCAPath != "" {
return nil, fmt.Errorf(
"Custom CA certificate will be supported in Vault 0.1.1")
} }
client := *http.DefaultClient client := *http.DefaultClient
@@ -186,3 +197,66 @@ func (m *Meta) TokenHelper() (*token.Helper, error) {
path = token.HelperPath(path) path = token.HelperPath(path)
return &token.Helper{Path: path}, nil return &token.Helper{Path: path}, nil
} }
func (m *Meta) loadCACert(path string) (*x509.CertPool, error) {
certs, err := m.loadCertFromPEM(path)
if err != nil {
return nil, fmt.Errorf("Error loading %s: %s", path, err)
}
result := x509.NewCertPool()
for _, cert := range certs {
result.AddCert(cert)
}
return result, nil
}
func (m *Meta) loadCAPath(path string) (*x509.CertPool, error) {
result := x509.NewCertPool()
fn := func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
certs, err := m.loadCertFromPEM(path)
if err != nil {
return fmt.Errorf("Error loading %s: %s", path, err)
}
for _, cert := range certs {
result.AddCert(cert)
}
return nil
}
return result, filepath.Walk(path, fn)
}
func (m *Meta) loadCertFromPEM(path string) ([]*x509.Certificate, error) {
pemCerts, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
certs := make([]*x509.Certificate, 0, 5)
for len(pemCerts) > 0 {
var block *pem.Block
block, pemCerts = pem.Decode(pemCerts)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, err
}
certs = append(certs, cert)
}
return certs, nil
}