agent: return a non-zero exit code on error (#9670)

* agent: return a non-zero exit code on error

* agent/template: always return on template server error, add case for error_on_missing_key

* agent: fix tests by updating Run params to use an errCh

* agent/template: add permission denied test case, clean up test var

* agent: use unbuffered errCh, emit fatal errors directly to the UI output

* agent: use oklog's run.Group to schedule subsystem runners (#9761)

* agent: use oklog's run.Group to schedule subsystem runners

* agent: clean up unused DoneCh, clean up agent's main Run func

* agent/template: use ts.stopped.CAS to atomically swap value

* fix tests

* fix tests

* agent/template: add timeout on TestRunServer

* agent: output error via logs and return a generic error on non-zero exit

* fix TestAgent_ExitAfterAuth

* agent/template: do not restart ct runner on new incoming token if exit_after_auth is set to true

* agent: drain ah.OutputCh after sink exits to avoid blocking on the channel

* use context.WithTimeout, expand comments around ordering of defer cancel()
Author: Calvin Leung Huang
Date: 2020-09-29 18:03:09 -07:00
Committed by: GitHub
Parent: b2927012ba
Commit: d54164f9e2
14 changed files with 519 additions and 188 deletions
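
The recurring theme in the diffs below is oklog's run.Group. Each call to g.Add registers an actor: an execute function paired with an interrupt function. g.Run starts every execute function, waits for the first one to return, invokes every interrupt function, waits for the remaining actors to exit, and returns the error from the first actor that exited. A minimal sketch of the pattern, using hypothetical actors rather than the agent's real subsystems:

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"

	"github.com/oklog/run"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	var g run.Group

	// Actor 1: watch for signal termination, mirroring the agent's
	// shutdown watcher.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt)
	g.Add(func() error {
		select {
		case <-sigCh:
			return nil
		case <-ctx.Done():
			return nil
		}
	}, func(error) {
		cancel()
	})

	// Actor 2: a worker that runs until its context is canceled.
	g.Add(func() error {
		<-ctx.Done()
		return nil
	}, func(error) {
		cancel()
	})

	// Run blocks until the first actor returns, interrupts the rest,
	// and returns that first actor's error (possibly nil).
	if err := g.Run(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}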


@@ -42,6 +42,7 @@ import (
"github.com/hashicorp/vault/sdk/version"
"github.com/kr/pretty"
"github.com/mitchellh/cli"
"github.com/oklog/run"
"github.com/posener/complete"
)
@@ -546,7 +547,21 @@ func (c *AgentCommand) Run(args []string) int {
// TODO: implement support for SIGHUP reloading of configuration
// signal.Notify(c.signalCh)
var ssDoneCh, ahDoneCh, tsDoneCh chan struct{}
var g run.Group
// This run group watches for signal termination
g.Add(func() error {
for {
select {
case <-c.ShutdownCh:
c.UI.Output("==> Vault agent shutdown triggered")
return nil
case <-ctx.Done():
return nil
}
}
}, func(error) {})
// Start auto-auth and sink servers
if method != nil {
enableTokenCh := len(config.Templates) > 0
@@ -557,14 +572,12 @@ func (c *AgentCommand) Run(args []string) int {
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
EnableTemplateTokenCh: enableTokenCh,
})
ahDoneCh = ah.DoneCh
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: c.logger.Named("sink.server"),
Client: client,
ExitAfterAuth: exitAfterAuth,
})
ssDoneCh = ss.DoneCh
ts := template.NewServer(&template.ServerConfig{
Logger: c.logger.Named("template.server"),
@@ -574,11 +587,46 @@ func (c *AgentCommand) Run(args []string) int {
Namespace: namespace,
ExitAfterAuth: exitAfterAuth,
})
tsDoneCh = ts.DoneCh
go ah.Run(ctx, method)
go ss.Run(ctx, ah.OutputCh, sinks)
go ts.Run(ctx, ah.TemplateTokenCh, config.Templates)
g.Add(func() error {
return ah.Run(ctx, method)
}, func(error) {
cancelFunc()
})
g.Add(func() error {
err := ss.Run(ctx, ah.OutputCh, sinks)
c.logger.Info("sinks finished, exiting")
// Start a goroutine to drain ah.OutputCh from this point onward
// to prevent ah.Run from being blocked.
go func() {
for {
select {
case <-ctx.Done():
return
case <-ah.OutputCh:
}
}
}()
// Wait until templates are rendered
if len(config.Templates) > 0 {
<-ts.DoneCh
}
return err
}, func(error) {
cancelFunc()
})
g.Add(func() error {
return ts.Run(ctx, ah.TemplateTokenCh, config.Templates)
}, func(error) {
cancelFunc()
ts.Stop()
})
}
// Server configuration output
@@ -609,27 +657,10 @@ func (c *AgentCommand) Run(args []string) int {
}
}()
select {
case <-ssDoneCh:
// This will happen if we exit-on-auth
c.logger.Info("sinks finished, exiting")
// allow any templates to be rendered
if tsDoneCh != nil {
<-tsDoneCh
}
case <-c.ShutdownCh:
c.UI.Output("==> Vault agent shutdown triggered")
cancelFunc()
if ahDoneCh != nil {
<-ahDoneCh
}
if ssDoneCh != nil {
<-ssDoneCh
}
if tsDoneCh != nil {
<-tsDoneCh
}
if err := g.Run(); err != nil {
c.logger.Error("runtime error encountered", "error", err)
c.UI.Error("Error encountered during run, refer to logs for more details.")
return 1
}
return 0
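
The drain goroutine in the sink actor above exists because ah.Run keeps sending to ah.OutputCh after the sink server has exited; with no receiver left, ah.Run would block on the send and g.Run would never return. A stripped-down sketch of the pattern, with hypothetical names:

package main

import "context"

// drainUntilDone consumes ch until ctx is canceled, so a producer that
// keeps sending never blocks once its real consumer has exited.
func drainUntilDone(ctx context.Context, ch <-chan string) {
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case <-ch:
				// Discard; the real consumer is gone.
			}
		}
	}()
}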


@@ -66,11 +66,7 @@ func TestAliCloudEndToEnd(t *testing.T) {
t.Fatal(err)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// We're going to feed alicloud auth creds via env variables.
if err := setAliCloudEnvCreds(); err != nil {
@@ -101,9 +97,18 @@ func TestAliCloudEndToEnd(t *testing.T) {
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
tmpFile, err := ioutil.TempFile("", "auth.tokensink.test.")
@@ -133,10 +138,25 @@ func TestAliCloudEndToEnd(t *testing.T) {
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
defer func() {
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
if stat, err := os.Lstat(tokenSinkFileName); err == nil {
t.Fatalf("expected err but got %s", stat)


@@ -177,11 +177,7 @@ func testAppRoleEndToEnd(t *testing.T, removeSecretIDFile bool, bindSecretID boo
os.Remove(out)
t.Logf("output: %s", out)
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
secretFromAgent := secret
if secretIDLess {
@@ -224,9 +220,18 @@ func testAppRoleEndToEnd(t *testing.T, removeSecretIDFile bool, bindSecretID boo
Client: client,
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
config := &sink.SinkConfig{
@@ -245,13 +250,26 @@ func testAppRoleEndToEnd(t *testing.T, removeSecretIDFile bool, bindSecretID boo
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
defer func() {
<-ss.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
// Check that no sink file exists
_, err = os.Lstat(out)
@@ -503,11 +521,7 @@ func testAppRoleWithWrapping(t *testing.T, bindSecretID bool, secretIDLess bool,
os.Remove(out)
t.Logf("output: %s", out)
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
secretFromAgent := secret
if secretIDLess {
@@ -550,9 +564,18 @@ func testAppRoleWithWrapping(t *testing.T, bindSecretID bool, secretIDLess bool,
Client: client,
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
config := &sink.SinkConfig{
@@ -571,13 +594,26 @@ func testAppRoleWithWrapping(t *testing.T, bindSecretID bool, secretIDLess bool,
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
defer func() {
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
// Check that no sink file exists
_, err = os.Lstat(out)


@@ -2,6 +2,7 @@ package auth
import (
"context"
"errors"
"math/rand"
"net/http"
"time"
@@ -39,7 +40,6 @@ type AuthConfig struct {
// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server
type AuthHandler struct {
DoneCh chan struct{}
OutputCh chan string
TemplateTokenCh chan string
logger hclog.Logger
@@ -60,7 +60,6 @@ type AuthHandlerConfig struct {
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
ah := &AuthHandler{
DoneCh: make(chan struct{}),
// This is buffered so that if we try to output after the sink server
// has been shut down, during agent shutdown, we won't block
OutputCh: make(chan string, 1),
@@ -83,16 +82,15 @@ func backoffOrQuit(ctx context.Context, backoff time.Duration) {
}
}
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
if am == nil {
panic("nil auth method")
return errors.New("auth handler: nil auth method")
}
ah.logger.Info("starting auth handler")
defer func() {
am.Shutdown()
close(ah.OutputCh)
close(ah.DoneCh)
close(ah.TemplateTokenCh)
ah.logger.Info("auth handler stopped")
}()
@@ -122,7 +120,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
for {
select {
case <-ctx.Done():
return
return nil
default:
}


@@ -79,7 +79,10 @@ func TestAuthHandler(t *testing.T) {
})
am := newUserpassTestMethod(t, client)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
// Consume tokens so we don't block
stopTime := time.Now().Add(5 * time.Second)
@@ -87,6 +90,11 @@ func TestAuthHandler(t *testing.T) {
consumption:
for {
select {
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
break consumption
case <-ah.OutputCh:
case <-ah.TemplateTokenCh:
// Nothing
@@ -95,8 +103,6 @@ consumption:
cancelFunc()
closed = true
}
case <-ah.DoneCh:
break consumption
}
}
}


@@ -79,11 +79,7 @@ func TestAWSEndToEnd(t *testing.T) {
t.Fatal(err)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// We're going to feed aws auth creds via env variables.
if err := setAwsEnvCreds(); err != nil {
@@ -114,9 +110,18 @@ func TestAWSEndToEnd(t *testing.T) {
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
tmpFile, err := ioutil.TempFile("", "auth.tokensink.test.")
@@ -146,10 +151,25 @@ func TestAWSEndToEnd(t *testing.T) {
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
defer func() {
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
if stat, err := os.Lstat(tokenSinkFileName); err == nil {
t.Fatalf("expected err but got %s", stat)


@@ -151,11 +151,7 @@ func TestCache_UsingAutoAuthToken(t *testing.T) {
os.Remove(out)
t.Logf("output: %s", out)
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
conf := map[string]interface{}{
"role_id_file_path": role,
@@ -199,9 +195,18 @@ func TestCache_UsingAutoAuthToken(t *testing.T) {
Client: client,
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
config := &sink.SinkConfig{
@@ -231,13 +236,25 @@ func TestCache_UsingAutoAuthToken(t *testing.T) {
}
inmemSinkConfig.Sink = inmemSink
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config, inmemSinkConfig})
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config, inmemSinkConfig})
}()
defer func() {
<-ss.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
// Check that no sink file exists
_, err = os.Lstat(out)


@@ -135,11 +135,7 @@ func testCertEndToEnd(t *testing.T, withCertRoleName, ahWrapping bool) {
logger.Trace("wrote dh param file", "path", dhpath)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
aaConfig := map[string]interface{}{}
@@ -165,9 +161,18 @@ func testCertEndToEnd(t *testing.T, withCertRoleName, ahWrapping bool) {
ahConfig.WrapTTL = 10 * time.Second
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
config := &sink.SinkConfig{
@@ -193,13 +198,25 @@ func testCertEndToEnd(t *testing.T, withCertRoleName, ahWrapping bool) {
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
defer func() {
<-ss.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
cloned, err := client.Clone()
if err != nil {
@@ -454,11 +471,7 @@ func TestCertEndToEnd_CertsInConfig(t *testing.T) {
// Auth handler (auto-auth) setup
// /////////////
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
am, err := agentcert.NewCertAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.cert"),
@@ -480,9 +493,18 @@ func TestCertEndToEnd_CertsInConfig(t *testing.T) {
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// /////////////
@@ -515,13 +537,25 @@ func TestCertEndToEnd_CertsInConfig(t *testing.T) {
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
defer func() {
<-ss.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
// Read the token from the sink
timeout := time.Now().Add(5 * time.Second)


@@ -90,11 +90,7 @@ func TestCFEndToEnd(t *testing.T) {
os.Setenv(credCF.EnvVarInstanceCertificate, testCFCerts.PathToInstanceCertificate)
os.Setenv(credCF.EnvVarInstanceKey, testCFCerts.PathToInstanceKey)
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
am, err := agentcf.NewCFAuthMethod(&auth.AuthConfig{
MountPath: "auth/cf",
@@ -112,9 +108,18 @@ func TestCFEndToEnd(t *testing.T) {
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
tmpFile, err := ioutil.TempFile("", "auth.tokensink.test.")
@@ -144,10 +149,25 @@ func TestCFEndToEnd(t *testing.T) {
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
defer func() {
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
if stat, err := os.Lstat(tokenSinkFileName); err == nil {
t.Fatalf("expected err but got %s", stat)


@@ -121,11 +121,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
logger.Trace("wrote dh param file", "path", dhpath)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.jwt"),
@@ -148,9 +144,18 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
ahConfig.WrapTTL = 10 * time.Second
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
errCh := make(chan error)
go func() {
errCh <- ah.Run(ctx, am)
}()
defer func() {
<-ah.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
config := &sink.SinkConfig{
@@ -176,13 +181,25 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
go func() {
errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
}()
defer func() {
<-ss.DoneCh
select {
case <-ctx.Done():
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// This has to be after the other defers so it happens first. It allows
// successful test runs to immediately cancel all of the runner goroutines
// and unblock any of the blocking defer calls above, preventing successful
// tests from taking the entire timeout duration.
defer cancel()
// Check that no jwt file exists
_, err = os.Lstat(in)


@@ -33,17 +33,25 @@ func TestSinkServer(t *testing.T) {
uuidStr, _ := uuid.GenerateUUID()
in := make(chan string)
sinks := []*sink.SinkConfig{fs1, fs2}
go ss.Run(ctx, in, sinks)
errCh := make(chan error)
go func() {
errCh <- ss.Run(ctx, in, sinks)
}()
// Seed a token
in <- uuidStr
// Give it time to finish writing
time.Sleep(1 * time.Second)
// Tell it to shut down and give it time to do so
cancelFunc()
<-ss.DoneCh
timer := time.AfterFunc(3*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
select {
case <-ctx.Done():
case err := <-errCh:
t.Fatal(err)
}
for _, path := range []string{path1, path2} {
fileBytes, err := ioutil.ReadFile(fmt.Sprintf("%s/token", path))
@@ -91,7 +99,10 @@ func TestSinkServerRetry(t *testing.T) {
in := make(chan string)
sinks := []*sink.SinkConfig{&sink.SinkConfig{Sink: b1}, &sink.SinkConfig{Sink: b2}}
go ss.Run(ctx, in, sinks)
errCh := make(chan error)
go func() {
errCh <- ss.Run(ctx, in, sinks)
}()
// Seed a token
in <- "bad"
@@ -117,5 +128,10 @@ func TestSinkServerRetry(t *testing.T) {
// Tell it to shut down and give it time to do so
cancelFunc()
<-ss.DoneCh
select {
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}


@@ -48,7 +48,6 @@ type SinkServerConfig struct {
// SinkServer is responsible for pushing tokens to sinks
type SinkServer struct {
DoneCh chan struct{}
logger hclog.Logger
client *api.Client
random *rand.Rand
@@ -58,7 +57,6 @@ type SinkServer struct {
func NewSinkServer(conf *SinkServerConfig) *SinkServer {
ss := &SinkServer{
DoneCh: make(chan struct{}),
logger: conf.Logger,
client: conf.Client,
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
@@ -71,7 +69,7 @@ func NewSinkServer(conf *SinkServerConfig) *SinkServer {
// Run executes the server's run loop, which is responsible for reading
// in new tokens and pushing them out to the various sinks.
func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) {
func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) error {
latestToken := new(string)
writeSink := func(currSink *SinkConfig, currToken string) error {
if currToken != *latestToken {
@@ -95,13 +93,12 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
}
if incoming == nil {
panic("incoming channel is nil")
return errors.New("sink server: incoming channel is nil")
}
ss.logger.Info("starting sink server")
defer func() {
ss.logger.Info("sink server stopped")
close(ss.DoneCh)
}()
type sinkToken struct {
@@ -112,7 +109,7 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
for {
select {
case <-ctx.Done():
return
return nil
case token := <-incoming:
if len(sinks) > 0 {
@@ -140,14 +137,14 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
ss.logger.Trace("no sinks, ignoring new token")
if ss.exitAfterAuth {
ss.logger.Trace("no sinks, exitAfterAuth, bye")
return
return nil
}
}
case st := <-sinkCh:
atomic.AddInt32(ss.remaining, -1)
select {
case <-ctx.Done():
return
return nil
default:
}
@@ -156,14 +153,14 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
select {
case <-ctx.Done():
return
return nil
case <-time.After(backoff):
atomic.AddInt32(ss.remaining, 1)
sinkCh <- st
}
} else {
if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
return
return nil
}
}
}


@@ -7,9 +7,13 @@ package template
import (
"context"
"errors"
"fmt"
"io"
"strings"
"go.uber.org/atomic"
ctconfig "github.com/hashicorp/consul-template/config"
ctlogging "github.com/hashicorp/consul-template/logging"
"github.com/hashicorp/consul-template/manager"
@@ -45,25 +49,35 @@ type Server struct {
config *ServerConfig
// runner is the consul-template runner
runner *manager.Runner
runner *manager.Runner
runnerStarted *atomic.Bool
// Templates holds the parsed Consul Templates
Templates []*ctconfig.TemplateConfig
// lookupMap is alist of templates indexed by their consul-template ID. This
// lookupMap is a list of templates indexed by their consul-template ID. This
// is used to ensure all Vault templates have been rendered before returning
// from the runner in the event we're using exit after auth.
lookupMap map[string][]*ctconfig.TemplateConfig
DoneCh chan struct{}
DoneCh chan struct{}
stopped *atomic.Bool
logger hclog.Logger
exitAfterAuth bool
// testingLimitRetry is used for tests to limit the number of retries
// performed by the template server
testingLimitRetry int
}
// NewServer returns a new configured server
func NewServer(conf *ServerConfig) *Server {
ts := Server{
DoneCh: make(chan struct{}),
stopped: atomic.NewBool(false),
runnerStarted: atomic.NewBool(false),
logger: conf.Logger,
config: conf,
exitAfterAuth: conf.ExitAfterAuth,
@@ -74,23 +88,23 @@ func NewServer(conf *ServerConfig) *Server {
// Run kicks off the internal Consul Template runner, and listens for changes to
// the token from the AuthHandler. If Done() is called on the context, shut down
// the Runner and return
func (ts *Server) Run(ctx context.Context, incoming chan string, templates []*ctconfig.TemplateConfig) {
latestToken := new(string)
ts.logger.Info("starting template server")
// defer the closing of the DoneCh
defer func() {
ts.logger.Info("template server stopped")
close(ts.DoneCh)
}()
func (ts *Server) Run(ctx context.Context, incoming chan string, templates []*ctconfig.TemplateConfig) error {
if incoming == nil {
panic("incoming channel is nil")
return errors.New("template server: incoming channel is nil")
}
// If there are no templates, return
latestToken := new(string)
ts.logger.Info("starting template server")
defer func() {
ts.logger.Info("template server stopped")
}()
// If there are no templates, we wait for context cancellation and then return
if len(templates) == 0 {
ts.logger.Info("no templates found")
return
<-ctx.Done()
return nil
}
// construct a consul template vault config based the agents vault
@@ -98,15 +112,13 @@ func (ts *Server) Run(ctx context.Context, incoming chan string, templates []*ct
var runnerConfig *ctconfig.Config
var runnerConfigErr error
if runnerConfig, runnerConfigErr = newRunnerConfig(ts.config, templates); runnerConfigErr != nil {
ts.logger.Error("template server failed to generate runner config", "error", runnerConfigErr)
return
return fmt.Errorf("template server failed to runner generate config: %w", runnerConfigErr)
}
var err error
ts.runner, err = manager.NewRunner(runnerConfig, false)
if err != nil {
ts.logger.Error("template server failed to create", "error", err)
return
return fmt.Errorf("template server failed to create: %w", err)
}
// Build the lookup map using the id mapping from the Template runner. This is
@@ -130,11 +142,20 @@ func (ts *Server) Run(ctx context.Context, incoming chan string, templates []*ct
select {
case <-ctx.Done():
ts.runner.Stop()
return
return nil
case token := <-incoming:
if token != *latestToken {
ts.logger.Info("template server received new token")
// If the runner was previously started and we intend to exit
// after auth, do not restart the runner if a new token is
// received.
if ts.exitAfterAuth && ts.runnerStarted.Load() {
ts.logger.Info("template server not restarting with new token with exit_after_auth set to true")
continue
}
ts.runner.Stop()
*latestToken = token
ctv := ctconfig.Config{
@@ -142,6 +163,13 @@ func (ts *Server) Run(ctx context.Context, incoming chan string, templates []*ct
Token: latestToken,
},
}
// If we're testing, limit retries to 3 attempts to avoid
// long test runs caused by exponential back-off
if ts.testingLimitRetry != 0 {
ctv.Vault.Retry = &ctconfig.RetryConfig{Attempts: &ts.testingLimitRetry}
}
runnerConfig = runnerConfig.Merge(&ctv)
var runnerErr error
ts.runner, runnerErr = manager.NewRunner(runnerConfig, false)
@@ -149,17 +177,14 @@ func (ts *Server) Run(ctx context.Context, incoming chan string, templates []*ct
ts.logger.Error("template server failed with new Vault token", "error", runnerErr)
continue
}
ts.runnerStarted.CAS(false, true)
go ts.runner.Start()
}
case err := <-ts.runner.ErrCh:
ts.logger.Error("template server error", "error", err.Error())
ts.runner.StopImmediately()
ts.runner, err = manager.NewRunner(runnerConfig, false)
if err != nil {
ts.logger.Error("template server failed to create", "error", err)
return
}
go ts.runner.Start()
return fmt.Errorf("template server: %w", err)
case <-ts.runner.TemplateRenderedCh():
// A template has been rendered, figure out what to do
events := ts.runner.RenderEvents()
@@ -185,12 +210,18 @@ func (ts *Server) Run(ctx context.Context, incoming chan string, templates []*ct
// return. The deferred closing of the DoneCh will allow agent to
// continue with closing down
ts.runner.Stop()
return
return nil
}
}
}
}
func (ts *Server) Stop() {
if ts.stopped.CAS(false, true) {
close(ts.DoneCh)
}
}
// newRunnerConfig returns a consul-template runner configuration, setting the
// Vault and Consul configurations based on the clients configs.
func newRunnerConfig(sc *ServerConfig, templates ctconfig.TemplateConfigs) (*ctconfig.Config, error) {
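
The new Stop method uses go.uber.org/atomic's Bool.CAS so that DoneCh is closed at most once, since closing an already-closed channel panics. A minimal sketch of the idempotent-stop pattern (a hypothetical type, same mechanics as ts.stopped above):

package main

import "go.uber.org/atomic"

type stopper struct {
	DoneCh  chan struct{}
	stopped *atomic.Bool
}

func newStopper() *stopper {
	return &stopper{
		DoneCh:  make(chan struct{}),
		stopped: atomic.NewBool(false),
	}
}

// Stop is safe to call any number of times from any goroutine: the
// compare-and-swap succeeds for exactly one caller, so DoneCh is
// closed exactly once and a double-close panic cannot occur.
func (s *stopper) Stop() {
	if s.stopped.CAS(false, true) {
		close(s.DoneCh)
	}
}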


@@ -9,6 +9,7 @@ import (
"net/http/httptest"
"os"
"testing"
"time"
ctconfig "github.com/hashicorp/consul-template/config"
"github.com/hashicorp/go-hclog"
@@ -28,7 +29,20 @@ func TestNewServer(t *testing.T) {
func TestServerRun(t *testing.T) {
// create http test server
ts := httptest.NewServer(http.HandlerFunc(handleRequest))
mux := http.NewServeMux()
mux.HandleFunc("/v1/kv/myapp/config", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, jsonResponse)
})
mux.HandleFunc("/v1/kv/myapp/config-bad", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
fmt.Fprintln(w, `{"errors":[]}`)
})
mux.HandleFunc("/v1/kv/myapp/perm-denied", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(403)
fmt.Fprintln(w, `{"errors":["1 error occurred:\n\t* permission denied\n\n"]}`)
})
ts := httptest.NewServer(mux)
defer ts.Close()
tmpDir, err := ioutil.TempDir("", "agent-tests")
defer os.RemoveAll(tmpDir)
@@ -50,6 +64,7 @@ func TestServerRun(t *testing.T) {
testCases := map[string]struct {
templateMap map[string]*templateTest
expectError bool
}{
"simple": {
templateMap: map[string]*templateTest{
@@ -59,6 +74,7 @@ func TestServerRun(t *testing.T) {
},
},
},
expectError: false,
},
"multiple": {
templateMap: map[string]*templateTest{
@@ -98,6 +114,38 @@ func TestServerRun(t *testing.T) {
},
},
},
expectError: false,
},
"bad secret": {
templateMap: map[string]*templateTest{
"render_01": &templateTest{
template: &ctconfig.TemplateConfig{
Contents: pointerutil.StringPtr(templateContentsBad),
},
},
},
expectError: true,
},
"missing key": {
templateMap: map[string]*templateTest{
"render_01": &templateTest{
template: &ctconfig.TemplateConfig{
Contents: pointerutil.StringPtr(templateContentsMissingKey),
ErrMissingKey: pointerutil.BoolPtr(true),
},
},
},
expectError: true,
},
"permission denied": {
templateMap: map[string]*templateTest{
"render_01": &templateTest{
template: &ctconfig.TemplateConfig{
Contents: pointerutil.StringPtr(templateContentsPermDenied),
},
},
},
expectError: true,
},
}
@@ -111,7 +159,7 @@ func TestServerRun(t *testing.T) {
templatesToRender = append(templatesToRender, templateTest.template)
}
ctx := context.Background()
ctx, _ := context.WithTimeout(context.Background(), 20*time.Second)
sc := ServerConfig{
Logger: logging.NewVaultLogger(hclog.Trace),
VaultConf: &config.Vault{
@@ -127,13 +175,29 @@ func TestServerRun(t *testing.T) {
if ts == nil {
t.Fatal("nil server returned")
}
server.testingLimitRetry = 3
go server.Run(ctx, templateTokenCh, templatesToRender)
errCh := make(chan error)
go func() {
errCh <- server.Run(ctx, templateTokenCh, templatesToRender)
}()
// send a dummy value to trigger the internal Runner to query for secret
// info
templateTokenCh <- "test"
<-server.DoneCh
select {
case <-ctx.Done():
t.Fatal("timeout reached before templates were rendered")
case err := <-errCh:
if err != nil && !tc.expectError {
t.Fatalf("did not expect error, got: %v", err)
}
if err != nil && tc.expectError {
t.Logf("received expected error: %v", err)
return
}
}
// verify test file exists and has the content we're looking for
var fileCount int
@@ -162,10 +226,6 @@ func TestServerRun(t *testing.T) {
}
}
func handleRequest(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, jsonResponse)
}
var jsonResponse = `
{
"request_id": "8af096e9-518c-7351-eff5-5ba20554b21f",
@@ -199,3 +259,31 @@ var templateContents = `
}
{{ end }}
`
var templateContentsMissingKey = `
{{ with secret "kv/myapp/config"}}
{
{{ if .Data.data.foo}}"foo":"{{ .Data.data.foo}}"{{ end }}
}
{{ end }}
`
var templateContentsBad = `
{{ with secret "kv/myapp/config-bad"}}
{
{{ if .Data.data.username}}"username":"{{ .Data.data.username}}",{{ end }}
{{ if .Data.data.password }}"password":"{{ .Data.data.password }}",{{ end }}
{{ if .Data.metadata.version}}"version":"{{ .Data.metadata.version }}"{{ end }}
}
{{ end }}
`
var templateContentsPermDenied = `
{{ with secret "kv/myapp/perm-denied"}}
{
{{ if .Data.data.username}}"username":"{{ .Data.data.username}}",{{ end }}
{{ if .Data.data.password }}"password":"{{ .Data.data.password }}",{{ end }}
{{ if .Data.metadata.version}}"version":"{{ .Data.metadata.version }}"{{ end }}
}
{{ end }}
`