Initial Atlas listener implementation

This commit is contained in:
Jeff Mitchell
2016-06-02 12:40:25 -04:00
parent 29c78f6512
commit d32283ba49
9 changed files with 132 additions and 20 deletions

View File

@@ -44,6 +44,8 @@ type ServerCommand struct {
meta.Meta meta.Meta
logger *log.Logger
ReloadFuncs map[string][]server.ReloadFunc ReloadFuncs map[string][]server.ReloadFunc
} }
@@ -136,7 +138,7 @@ func (c *ServerCommand) Run(args []string) int {
// Create a logger. We wrap it in a gated writer so that it doesn't // Create a logger. We wrap it in a gated writer so that it doesn't
// start logging too early. // start logging too early.
logGate := &gatedwriter.Writer{Writer: os.Stderr} logGate := &gatedwriter.Writer{Writer: os.Stderr}
logger := log.New(&logutils.LevelFilter{ c.logger = log.New(&logutils.LevelFilter{
Levels: []logutils.LogLevel{ Levels: []logutils.LogLevel{
"TRACE", "DEBUG", "INFO", "WARN", "ERR"}, "TRACE", "DEBUG", "INFO", "WARN", "ERR"},
MinLevel: logutils.LogLevel(strings.ToUpper(logLevel)), MinLevel: logutils.LogLevel(strings.ToUpper(logLevel)),
@@ -150,7 +152,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the backend // Initialize the backend
backend, err := physical.NewBackend( backend, err := physical.NewBackend(
config.Backend.Type, logger, config.Backend.Config) config.Backend.Type, c.logger, config.Backend.Config)
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf( c.Ui.Error(fmt.Sprintf(
"Error initializing backend of type %s: %s", "Error initializing backend of type %s: %s",
@@ -179,7 +181,7 @@ func (c *ServerCommand) Run(args []string) int {
AuditBackends: c.AuditBackends, AuditBackends: c.AuditBackends,
CredentialBackends: c.CredentialBackends, CredentialBackends: c.CredentialBackends,
LogicalBackends: c.LogicalBackends, LogicalBackends: c.LogicalBackends,
Logger: logger, Logger: c.logger,
DisableCache: config.DisableCache, DisableCache: config.DisableCache,
DisableMlock: config.DisableMlock, DisableMlock: config.DisableMlock,
MaxLeaseTTL: config.MaxLeaseTTL, MaxLeaseTTL: config.MaxLeaseTTL,
@@ -190,7 +192,7 @@ func (c *ServerCommand) Run(args []string) int {
var ok bool var ok bool
if config.HABackend != nil { if config.HABackend != nil {
habackend, err := physical.NewBackend( habackend, err := physical.NewBackend(
config.HABackend.Type, logger, config.HABackend.Config) config.HABackend.Type, c.logger, config.HABackend.Config)
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf( c.Ui.Error(fmt.Sprintf(
"Error initializing backend of type %s: %s", "Error initializing backend of type %s: %s",
@@ -322,7 +324,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the listeners // Initialize the listeners
lns := make([]net.Listener, 0, len(config.Listeners)) lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners { for i, lnConfig := range config.Listeners {
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config) ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, logGate)
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf( c.Ui.Error(fmt.Sprintf(
"Error initializing listener of type %s: %s", "Error initializing listener of type %s: %s",
@@ -351,6 +353,13 @@ func (c *ServerCommand) Run(args []string) int {
} }
} }
// Make sure we close all listeners from this point on
defer func() {
for _, ln := range lns {
ln.Close()
}
}()
infoKeys = append(infoKeys, "version") infoKeys = append(infoKeys, "version")
info["version"] = version.GetVersion().String() info["version"] = version.GetVersion().String()
@@ -368,9 +377,6 @@ func (c *ServerCommand) Run(args []string) int {
c.Ui.Output("") c.Ui.Output("")
if verifyOnly { if verifyOnly {
for _, listener := range lns {
listener.Close()
}
return 0 return 0
} }
@@ -410,10 +416,6 @@ func (c *ServerCommand) Run(args []string) int {
} }
} }
for _, listener := range lns {
listener.Close()
}
return 0 return 0
} }

View File

@@ -200,6 +200,7 @@ func ParseConfig(d string) (*Config, error) {
} }
valid := []string{ valid := []string{
"atlas",
"backend", "backend",
"ha_backend", "ha_backend",
"listener", "listener",
@@ -414,6 +415,8 @@ func parseHABackends(result *Config, list *ast.ObjectList) error {
} }
func parseListeners(result *Config, list *ast.ObjectList) error { func parseListeners(result *Config, list *ast.ObjectList) error {
var foundAtlas bool
listeners := make([]*Listener, 0, len(list.Items)) listeners := make([]*Listener, 0, len(list.Items))
for _, item := range list.Items { for _, item := range list.Items {
key := "listener" key := "listener"
@@ -423,10 +426,13 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
valid := []string{ valid := []string{
"address", "address",
"endpoint",
"infrastructure",
"tls_disable", "tls_disable",
"tls_cert_file", "tls_cert_file",
"tls_key_file", "tls_key_file",
"tls_min_version", "tls_min_version",
"token",
} }
if err := checkHCLKeys(item.Val, valid); err != nil { if err := checkHCLKeys(item.Val, valid); err != nil {
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key)) return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
@@ -437,8 +443,24 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key)) return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
} }
lnType := strings.ToLower(key)
if lnType == "atlas" {
if foundAtlas {
return multierror.Prefix(fmt.Errorf("only one listener of type 'atlas' is permitted"), fmt.Sprintf("listeners.%s", key))
} else {
foundAtlas = true
if m["token"] == "" {
return multierror.Prefix(fmt.Errorf("'token' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
}
if m["infrastructure"] == "" {
return multierror.Prefix(fmt.Errorf("'infrastructure' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
}
}
}
listeners = append(listeners, &Listener{ listeners = append(listeners, &Listener{
Type: strings.ToLower(key), Type: lnType,
Config: m, Config: m,
}) })
} }

View File

@@ -21,6 +21,14 @@ func TestLoadConfigFile(t *testing.T) {
"address": "127.0.0.1:443", "address": "127.0.0.1:443",
}, },
}, },
&Listener{
Type: "atlas",
Config: map[string]string{
Token: "foobar",
Infrastructure: "foo/bar",
Endpoint: "https://foo.bar:1111",
},
},
}, },
Backend: &Backend{ Backend: &Backend{
@@ -72,6 +80,14 @@ func TestLoadConfigFile_json(t *testing.T) {
"address": "127.0.0.1:443", "address": "127.0.0.1:443",
}, },
}, },
&Listener{
Type: "atlas",
Config: map[string]string{
Token: "foobar",
Infrastructure: "foo/bar",
Endpoint: "https://foo.bar:1111",
},
},
}, },
Backend: &Backend{ Backend: &Backend{

View File

@@ -6,17 +6,19 @@ import (
_ "crypto/sha512" _ "crypto/sha512"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"net" "net"
"strconv" "strconv"
"sync" "sync"
) )
// ListenerFactory is the factory function to create a listener. // ListenerFactory is the factory function to create a listener.
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFunc, error) type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, ReloadFunc, error)
// BuiltinListeners is the list of built-in listener types. // BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{ var BuiltinListeners = map[string]ListenerFactory{
"tcp": tcpListenerFactory, "tcp": tcpListenerFactory,
"atlas": atlasListenerFactory,
} }
// tlsLookup maps the tls_min_version configuration to the internal value // tlsLookup maps the tls_min_version configuration to the internal value
@@ -28,13 +30,13 @@ var tlsLookup = map[string]uint16{
// NewListener creates a new listener of the given type with the given // NewListener creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map. // configuration. The type is looked up in the BuiltinListeners map.
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
f, ok := BuiltinListeners[t] f, ok := BuiltinListeners[t]
if !ok { if !ok {
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t) return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
} }
return f(config) return f(config, logger)
} }
func listenerWrapTLS( func listenerWrapTLS(

View File

@@ -0,0 +1,58 @@
package server
import (
"io"
"net"
"github.com/hashicorp/scada-client/scada"
"github.com/hashicorp/vault/version"
)
type SCADAListener struct {
ln net.Listener
scadaProvider *scada.Provider
}
func (s *SCADAListener) Accept() (net.Conn, error) {
return s.ln.Accept()
}
func (s *SCADAListener) Close() error {
s.scadaProvider.Shutdown()
return s.ln.Close()
}
func (s *SCADAListener) Addr() net.Addr {
return s.ln.Addr()
}
func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
scadaConfig := &scada.Config{
Service: "vault",
Version: version.GetVersion().String(),
ResourceType: "vault-cluster",
Meta: map[string]string{},
Atlas: scada.AtlasConfig{
Endpoint: config["endpoint"],
Infrastructure: config["infrastructure"],
Token: config["token"],
},
}
provider, list, err := scada.NewHTTPProvider(scadaConfig, logger)
if err != nil {
return nil, nil, nil, err
}
ln := &SCADAListener{
ln: list,
scadaProvider: provider,
}
props := map[string]string{
"addr": "Atlas/SCADA",
"infrastructure": scadaConfig.Atlas.Infrastructure,
}
return listenerWrapTLS(ln, props, config)
}

View File

@@ -1,11 +1,12 @@
package server package server
import ( import (
"io"
"net" "net"
"time" "time"
) )
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) { func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
addr, ok := config["address"] addr, ok := config["address"]
if !ok { if !ok {
addr = "127.0.0.1:8200" addr = "127.0.0.1:8200"

View File

@@ -16,7 +16,7 @@ func TestTCPListener(t *testing.T) {
ln, _, _, err := tcpListenerFactory(map[string]string{ ln, _, _, err := tcpListenerFactory(map[string]string{
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"tls_disable": "1", "tls_disable": "1",
}) }, nil)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -52,7 +52,7 @@ func TestTCPListener_tls(t *testing.T) {
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"tls_cert_file": wd + "reload_foo.pem", "tls_cert_file": wd + "reload_foo.pem",
"tls_key_file": wd + "reload_foo.key", "tls_key_file": wd + "reload_foo.key",
}) }, nil)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View File

@@ -3,6 +3,12 @@ disable_mlock = true
statsd_addr = "bar" statsd_addr = "bar"
statsite_addr = "foo" statsite_addr = "foo"
listener "atlas" {
token = "foobar"
infrastructure = "foo/bar"
endpoint = "https://foo.bar:1111"
}
listener "tcp" { listener "tcp" {
address = "127.0.0.1:443" address = "127.0.0.1:443"
} }

View File

@@ -4,6 +4,11 @@
"address":"127.0.0.1:443" "address":"127.0.0.1:443"
} }
}, },
"atlas":{
"token":"foobar",
"infrastructure":"foo/bar",
"endpoint":"https://foo.bar:1111"
},
"backend":{ "backend":{
"consul":{ "consul":{
"foo":"bar" "foo":"bar"