mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	 3fcb1a67c5
			
		
	
	3fcb1a67c5
	
	
	
		
			
			* add inline cert auth to postres db plugin * handle both sslinline and new TLS plugin fields * refactor PrepareTestContainerWithSSL * add tests for postgres inline TLS fields * changelog * revert back to errwrap since the middleware sanitizing depends on it * enable only setting sslrootcert
		
			
				
	
	
		
			467 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			467 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) 2019-2021 Jack Christensen
 | |
| 
 | |
| // MIT License
 | |
| 
 | |
| // Permission is hereby granted, free of charge, to any person obtaining
 | |
| // a copy of this software and associated documentation files (the
 | |
| // "Software"), to deal in the Software without restriction, including
 | |
| // without limitation the rights to use, copy, modify, merge, publish,
 | |
| // distribute, sublicense, and/or sell copies of the Software, and to
 | |
| // permit persons to whom the Software is furnished to do so, subject to
 | |
| // the following conditions:
 | |
| 
 | |
| // The above copyright notice and this permission notice shall be
 | |
| // included in all copies or substantial portions of the Software.
 | |
| 
 | |
| // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 | |
| // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 | |
| // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 | |
| // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
 | |
| // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
 | |
| // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 | |
| // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 | |
| 
 | |
| // Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go
 | |
| 
 | |
| package connutil
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"crypto/tls"
 | |
| 	"crypto/x509"
 | |
| 	"database/sql"
 | |
| 	"encoding/pem"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"math"
 | |
| 	"net"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/hashicorp/vault/sdk/helper/pluginutil"
 | |
| 	"github.com/jackc/pgconn"
 | |
| 	"github.com/jackc/pgx/v4"
 | |
| 	"github.com/jackc/pgx/v4/stdlib"
 | |
| )
 | |
| 
 | |
| // openPostgres parses the connection string and opens a connection to the database.
 | |
| //
 | |
| // If sslinline is set, strips the connection string of all ssl settings and
 | |
| // creates a TLS config based on the settings provided, then uses the
 | |
| // RegisterConnConfig function to create a new connection. This is necessary
 | |
| // because the pgx driver does not support the sslinline parameter and instead
 | |
| // expects to source ssl material from the file system.
 | |
| //
 | |
| // Deprecated: openPostgres will be removed in a future version of the Vault SDK.
 | |
| func openPostgres(driverName, connString string) (*sql.DB, error) {
 | |
| 	if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok {
 | |
| 		return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable")
 | |
| 	}
 | |
| 
 | |
| 	var options pgconn.ParseConfigOptions
 | |
| 
 | |
| 	settings := make(map[string]string)
 | |
| 	if connString != "" {
 | |
| 		var err error
 | |
| 		// connString may be a database URL or a DSN
 | |
| 		if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
 | |
| 			settings, err = parsePostgresURLSettings(connString)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("failed to parse as URL: %w", err)
 | |
| 			}
 | |
| 		} else {
 | |
| 			settings, err = parsePostgresDSNSettings(connString)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("failed to parse as DSN: %w", err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// get the inline flag
 | |
| 	sslInline := settings["sslinline"] == "true"
 | |
| 
 | |
| 	// if sslinline is not set, open a regular connection
 | |
| 	if !sslInline {
 | |
| 		return sql.Open(driverName, connString)
 | |
| 	}
 | |
| 
 | |
| 	// generate a new DSN without the ssl settings
 | |
| 	newConnStr := []string{"sslmode=disable"}
 | |
| 	for k, v := range settings {
 | |
| 		switch k {
 | |
| 		case "sslinline", "sslcert", "sslkey", "sslrootcert", "sslmode":
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		newConnStr = append(newConnStr, fmt.Sprintf("%s='%s'", k, v))
 | |
| 	}
 | |
| 
 | |
| 	// parse the updated config
 | |
| 	config, err := pgx.ParseConfig(strings.Join(newConnStr, " "))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// create a TLS config
 | |
| 	fallbacks := []*pgconn.FallbackConfig{}
 | |
| 
 | |
| 	hosts := strings.Split(settings["host"], ",")
 | |
| 	ports := strings.Split(settings["port"], ",")
 | |
| 
 | |
| 	for i, host := range hosts {
 | |
| 		var portStr string
 | |
| 		if i < len(ports) {
 | |
| 			portStr = ports[i]
 | |
| 		} else {
 | |
| 			portStr = ports[0]
 | |
| 		}
 | |
| 
 | |
| 		port, err := parsePort(portStr)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("invalid port: %w", err)
 | |
| 		}
 | |
| 
 | |
| 		var tlsConfigs []*tls.Config
 | |
| 
 | |
| 		// Ignore TLS settings if Unix domain socket like libpq
 | |
| 		if network, _ := pgconn.NetworkAddress(host, port); network == "unix" {
 | |
| 			tlsConfigs = append(tlsConfigs, nil)
 | |
| 		} else {
 | |
| 			var err error
 | |
| 			tlsConfigs, err = configPostgresTLS(settings, host, options)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("failed to configure TLS: %w", err)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		for _, tlsConfig := range tlsConfigs {
 | |
| 			fallbacks = append(fallbacks, &pgconn.FallbackConfig{
 | |
| 				Host:      host,
 | |
| 				Port:      port,
 | |
| 				TLSConfig: tlsConfig,
 | |
| 			})
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	config.Host = fallbacks[0].Host
 | |
| 	config.Port = fallbacks[0].Port
 | |
| 	config.TLSConfig = fallbacks[0].TLSConfig
 | |
| 	config.Fallbacks = fallbacks[1:]
 | |
| 
 | |
| 	return sql.Open(driverName, stdlib.RegisterConnConfig(config))
 | |
| }
 | |
| 
 | |
| // configPostgresTLS uses libpq's TLS parameters to construct  []*tls.Config. It is
 | |
| // necessary to allow returning multiple TLS configs as sslmode "allow" and
 | |
| // "prefer" allow fallback.
 | |
| //
 | |
| // Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go
 | |
| // and modified to read ssl material by value instead of file location.
 | |
| func configPostgresTLS(settings map[string]string, thisHost string, parseConfigOptions pgconn.ParseConfigOptions) ([]*tls.Config, error) {
 | |
| 	host := thisHost
 | |
| 	sslmode := settings["sslmode"]
 | |
| 	sslrootcert := settings["sslrootcert"]
 | |
| 	sslcert := settings["sslcert"]
 | |
| 	sslkey := settings["sslkey"]
 | |
| 	sslpassword := settings["sslpassword"]
 | |
| 	sslsni := settings["sslsni"]
 | |
| 
 | |
| 	// Match libpq default behavior
 | |
| 	if sslmode == "" {
 | |
| 		sslmode = "prefer"
 | |
| 	}
 | |
| 	if sslsni == "" {
 | |
| 		sslsni = "1"
 | |
| 	}
 | |
| 
 | |
| 	tlsConfig := &tls.Config{}
 | |
| 
 | |
| 	switch sslmode {
 | |
| 	case "disable":
 | |
| 		return []*tls.Config{nil}, nil
 | |
| 	case "allow", "prefer":
 | |
| 		tlsConfig.InsecureSkipVerify = true
 | |
| 	case "require":
 | |
| 		// According to PostgreSQL documentation, if a root CA file exists,
 | |
| 		// the behavior of sslmode=require should be the same as that of verify-ca
 | |
| 		//
 | |
| 		// See https://www.postgresql.org/docs/12/libpq-ssl.html
 | |
| 		if sslrootcert != "" {
 | |
| 			goto nextCase
 | |
| 		}
 | |
| 		tlsConfig.InsecureSkipVerify = true
 | |
| 		break
 | |
| 	nextCase:
 | |
| 		fallthrough
 | |
| 	case "verify-ca":
 | |
| 		// Don't perform the default certificate verification because it
 | |
| 		// will verify the hostname. Instead, verify the server's
 | |
| 		// certificate chain ourselves in VerifyPeerCertificate and
 | |
| 		// ignore the server name. This emulates libpq's verify-ca
 | |
| 		// behavior.
 | |
| 		//
 | |
| 		// See https://github.com/golang/go/issues/21971#issuecomment-332693931
 | |
| 		// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
 | |
| 		// for more info.
 | |
| 		tlsConfig.InsecureSkipVerify = true
 | |
| 		tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
 | |
| 			certs := make([]*x509.Certificate, len(certificates))
 | |
| 			for i, asn1Data := range certificates {
 | |
| 				cert, err := x509.ParseCertificate(asn1Data)
 | |
| 				if err != nil {
 | |
| 					return errors.New("failed to parse certificate from server: " + err.Error())
 | |
| 				}
 | |
| 				certs[i] = cert
 | |
| 			}
 | |
| 
 | |
| 			// Leave DNSName empty to skip hostname verification.
 | |
| 			opts := x509.VerifyOptions{
 | |
| 				Roots:         tlsConfig.RootCAs,
 | |
| 				Intermediates: x509.NewCertPool(),
 | |
| 			}
 | |
| 			// Skip the first cert because it's the leaf. All others
 | |
| 			// are intermediates.
 | |
| 			for _, cert := range certs[1:] {
 | |
| 				opts.Intermediates.AddCert(cert)
 | |
| 			}
 | |
| 			_, err := certs[0].Verify(opts)
 | |
| 			return err
 | |
| 		}
 | |
| 	case "verify-full":
 | |
| 		tlsConfig.ServerName = host
 | |
| 	default:
 | |
| 		return nil, errors.New("sslmode is invalid")
 | |
| 	}
 | |
| 
 | |
| 	if sslrootcert != "" {
 | |
| 		caCertPool := x509.NewCertPool()
 | |
| 		if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) {
 | |
| 			return nil, errors.New("unable to add CA to cert pool")
 | |
| 		}
 | |
| 
 | |
| 		tlsConfig.RootCAs = caCertPool
 | |
| 		tlsConfig.ClientCAs = caCertPool
 | |
| 	}
 | |
| 
 | |
| 	if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
 | |
| 		return nil, errors.New(`both "sslcert" and "sslkey" are required`)
 | |
| 	}
 | |
| 
 | |
| 	if sslcert != "" && sslkey != "" {
 | |
| 		block, _ := pem.Decode([]byte(sslkey))
 | |
| 		var pemKey []byte
 | |
| 		var decryptedKey []byte
 | |
| 		var decryptedError error
 | |
| 		// If PEM is encrypted, attempt to decrypt using pass phrase
 | |
| 		if x509.IsEncryptedPEMBlock(block) {
 | |
| 			// Attempt decryption with pass phrase
 | |
| 			// NOTE: only supports RSA (PKCS#1)
 | |
| 			if sslpassword != "" {
 | |
| 				decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
 | |
| 			}
 | |
| 			// if sslpassword not provided or has decryption error when use it
 | |
| 			// try to find sslpassword with callback function
 | |
| 			if sslpassword == "" || decryptedError != nil {
 | |
| 				if parseConfigOptions.GetSSLPassword != nil {
 | |
| 					sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
 | |
| 				}
 | |
| 				if sslpassword == "" {
 | |
| 					return nil, fmt.Errorf("unable to find sslpassword")
 | |
| 				}
 | |
| 			}
 | |
| 			decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
 | |
| 			// Should we also provide warning for PKCS#1 needed?
 | |
| 			if decryptedError != nil {
 | |
| 				return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError)
 | |
| 			}
 | |
| 
 | |
| 			pemBytes := pem.Block{
 | |
| 				Type:  "RSA PRIVATE KEY",
 | |
| 				Bytes: decryptedKey,
 | |
| 			}
 | |
| 			pemKey = pem.EncodeToMemory(&pemBytes)
 | |
| 		} else {
 | |
| 			pemKey = pem.EncodeToMemory(block)
 | |
| 		}
 | |
| 
 | |
| 		cert, err := tls.X509KeyPair([]byte(sslcert), pemKey)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("unable to load cert: %w", err)
 | |
| 		}
 | |
| 		tlsConfig.Certificates = []tls.Certificate{cert}
 | |
| 	}
 | |
| 
 | |
| 	// Set Server Name Indication (SNI), if enabled by connection parameters.
 | |
| 	// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
 | |
| 	// or IPv6).
 | |
| 	if sslsni == "1" && net.ParseIP(host) == nil {
 | |
| 		tlsConfig.ServerName = host
 | |
| 	}
 | |
| 
 | |
| 	switch sslmode {
 | |
| 	case "allow":
 | |
| 		return []*tls.Config{nil, tlsConfig}, nil
 | |
| 	case "prefer":
 | |
| 		return []*tls.Config{tlsConfig, nil}, nil
 | |
| 	case "require", "verify-ca", "verify-full":
 | |
| 		return []*tls.Config{tlsConfig}, nil
 | |
| 	default:
 | |
| 		panic("BUG: bad sslmode should already have been caught")
 | |
| 	}
 | |
| }
 | |
| 
 | |
// parsePort converts a decimal string into a TCP port number. It returns the
// strconv error when s is not a base-10 unsigned integer that fits in 16 bits,
// and an "outside range" error when the parsed value is zero.
func parsePort(s string) (uint16, error) {
	parsed, err := strconv.ParseUint(s, 10, 16)
	switch {
	case err != nil:
		return 0, err
	case parsed == 0, parsed > math.MaxUint16:
		return 0, errors.New("outside range")
	}
	return uint16(parsed), nil
}
 | |
| 
 | |
// asciiSpace is a lookup table indexed by byte value: the entry is 1 for the
// ASCII whitespace bytes (space, \t, \n, \v, \f, \r) and 0 for every other byte.
var asciiSpace = [256]uint8{
	' ':  1,
	'\t': 1,
	'\n': 1,
	'\v': 1,
	'\f': 1,
	'\r': 1,
}
 | |
| 
 | |
| func parsePostgresURLSettings(connString string) (map[string]string, error) {
 | |
| 	settings := make(map[string]string)
 | |
| 
 | |
| 	url, err := url.Parse(connString)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if url.User != nil {
 | |
| 		settings["user"] = url.User.Username()
 | |
| 		if password, present := url.User.Password(); present {
 | |
| 			settings["password"] = password
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
 | |
| 	var hosts []string
 | |
| 	var ports []string
 | |
| 	for _, host := range strings.Split(url.Host, ",") {
 | |
| 		if host == "" {
 | |
| 			continue
 | |
| 		}
 | |
| 		if isIPOnly(host) {
 | |
| 			hosts = append(hosts, strings.Trim(host, "[]"))
 | |
| 			continue
 | |
| 		}
 | |
| 		h, p, err := net.SplitHostPort(host)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
 | |
| 		}
 | |
| 		if h != "" {
 | |
| 			hosts = append(hosts, h)
 | |
| 		}
 | |
| 		if p != "" {
 | |
| 			ports = append(ports, p)
 | |
| 		}
 | |
| 	}
 | |
| 	if len(hosts) > 0 {
 | |
| 		settings["host"] = strings.Join(hosts, ",")
 | |
| 	}
 | |
| 	if len(ports) > 0 {
 | |
| 		settings["port"] = strings.Join(ports, ",")
 | |
| 	}
 | |
| 
 | |
| 	database := strings.TrimLeft(url.Path, "/")
 | |
| 	if database != "" {
 | |
| 		settings["database"] = database
 | |
| 	}
 | |
| 
 | |
| 	nameMap := map[string]string{
 | |
| 		"dbname": "database",
 | |
| 	}
 | |
| 
 | |
| 	for k, v := range url.Query() {
 | |
| 		if k2, present := nameMap[k]; present {
 | |
| 			k = k2
 | |
| 		}
 | |
| 
 | |
| 		settings[k] = v[0]
 | |
| 	}
 | |
| 
 | |
| 	return settings, nil
 | |
| }
 | |
| 
 | |
| func parsePostgresDSNSettings(s string) (map[string]string, error) {
 | |
| 	settings := make(map[string]string)
 | |
| 
 | |
| 	nameMap := map[string]string{
 | |
| 		"dbname": "database",
 | |
| 	}
 | |
| 
 | |
| 	for len(s) > 0 {
 | |
| 		var key, val string
 | |
| 		eqIdx := strings.IndexRune(s, '=')
 | |
| 		if eqIdx < 0 {
 | |
| 			return nil, errors.New("invalid dsn")
 | |
| 		}
 | |
| 
 | |
| 		key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
 | |
| 		s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
 | |
| 		if len(s) == 0 {
 | |
| 		} else if s[0] != '\'' {
 | |
| 			end := 0
 | |
| 			for ; end < len(s); end++ {
 | |
| 				if asciiSpace[s[end]] == 1 {
 | |
| 					break
 | |
| 				}
 | |
| 				if s[end] == '\\' {
 | |
| 					end++
 | |
| 					if end == len(s) {
 | |
| 						return nil, errors.New("invalid backslash")
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
 | |
| 			if end == len(s) {
 | |
| 				s = ""
 | |
| 			} else {
 | |
| 				s = s[end+1:]
 | |
| 			}
 | |
| 		} else { // quoted string
 | |
| 			s = s[1:]
 | |
| 			end := 0
 | |
| 			for ; end < len(s); end++ {
 | |
| 				if s[end] == '\'' {
 | |
| 					break
 | |
| 				}
 | |
| 				if s[end] == '\\' {
 | |
| 					end++
 | |
| 				}
 | |
| 			}
 | |
| 			if end == len(s) {
 | |
| 				return nil, errors.New("unterminated quoted string in connection info string")
 | |
| 			}
 | |
| 			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
 | |
| 			if end == len(s) {
 | |
| 				s = ""
 | |
| 			} else {
 | |
| 				s = s[end+1:]
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if k, ok := nameMap[key]; ok {
 | |
| 			key = k
 | |
| 		}
 | |
| 
 | |
| 		if key == "" {
 | |
| 			return nil, errors.New("invalid dsn")
 | |
| 		}
 | |
| 
 | |
| 		settings[key] = val
 | |
| 	}
 | |
| 
 | |
| 	return settings, nil
 | |
| }
 | |
| 
 | |
// isIPOnly reports whether host carries no port component: either it contains
// no colon at all, or, once surrounding square brackets are removed, it
// parses as an IP address (which covers IPv6 literals that contain colons).
func isIPOnly(host string) bool {
	if !strings.Contains(host, ":") {
		return true
	}
	return net.ParseIP(strings.Trim(host, "[]")) != nil
}
 |