From d22204914d95d80e74effd94150586ed1ba8e535 Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Wed, 20 Jul 2016 15:38:53 -0400 Subject: [PATCH] Add service discovery to init command --- api/sys_init.go | 2 +- command/init.go | 146 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 131 insertions(+), 17 deletions(-) diff --git a/api/sys_init.go b/api/sys_init.go index 47e9718247..37c2bcc8ca 100644 --- a/api/sys_init.go +++ b/api/sys_init.go @@ -45,7 +45,7 @@ type InitStatusResponse struct { } type InitResponse struct { - Keys []string + Keys []string `json:"keys"` RecoveryKeys []string `json:"recovery_keys"` RootToken string `json:"root_token"` } diff --git a/command/init.go b/command/init.go index a839973de7..e2e6bc960d 100644 --- a/command/init.go +++ b/command/init.go @@ -2,8 +2,11 @@ package command import ( "fmt" + "os" + "runtime" "strings" + consulapi "github.com/hashicorp/consul/api" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/meta" @@ -18,6 +21,7 @@ func (c *InitCommand) Run(args []string) int { var threshold, shares, storedShares, recoveryThreshold, recoveryShares int var pgpKeys, recoveryPgpKeys pgpkeys.PubKeyFilesFlag var check bool + var auto string flags := c.Meta.FlagSet("init", meta.FlagSetDefault) flags.Usage = func() { c.Ui.Error(c.Help()) } flags.IntVar(&shares, "key-shares", 5, "") @@ -28,10 +32,128 @@ func (c *InitCommand) Run(args []string) int { flags.IntVar(&recoveryThreshold, "recovery-threshold", 3, "") flags.Var(&recoveryPgpKeys, "recovery-pgp-keys", "") flags.BoolVar(&check, "check", false, "") + flags.StringVar(&auto, "auto", "", "") if err := flags.Parse(args); err != nil { return 1 } + initRequest := &api.InitRequest{ + SecretShares: shares, + SecretThreshold: threshold, + StoredShares: storedShares, + PGPKeys: pgpKeys, + RecoveryShares: recoveryShares, + RecoveryThreshold: recoveryThreshold, + RecoveryPGPKeys: recoveryPgpKeys, + } + + // If running in 'auto' mode, run service discovery based on environment + // variables of Consul. + if auto != "" { + // Create configuration for Consul + consulConfig := consulapi.DefaultConfig() + + // Create a client to communicate with Consul + consulClient, err := consulapi.NewClient(consulConfig) + if err != nil { + c.Ui.Error(fmt.Sprintf("failed to create Consul client:%v", err)) + return 1 + } + + var uninitializedVaults []string + var initializedVault string + + // Query the nodes belonging to the cluster + if services, _, err := consulClient.Catalog().Service(auto, "", &consulapi.QueryOptions{AllowStale: true}); err == nil { + Loop: + for _, service := range services { + vaultAddress := fmt.Sprintf("%s://%s:%d", consulConfig.Scheme, service.ServiceAddress, service.ServicePort) + + // Set VAULT_ADDR to the discovered node + os.Setenv(api.EnvVaultAddress, vaultAddress) + + // Create a client to communicate with the discovered node + client, err := c.Client() + if err != nil { + c.Ui.Error(fmt.Sprintf( + "Error initializing client: %s", err)) + return 1 + } + + // Check the initialization status of the discovered node + inited, err := client.Sys().InitStatus() + switch { + case err != nil: + c.Ui.Error(fmt.Sprintf("Error checking initialization status of discovered node: %s err:%s", vaultAddress, err)) + return 1 + case inited: + // One of the nodes in the cluster is initialized. Break out. + initializedVault = vaultAddress + break Loop + default: + // Vault is uninitialized. + uninitializedVaults = append(uninitializedVaults, vaultAddress) + } + } + } + + export := "export" + quote := "'" + if runtime.GOOS == "windows" { + export = "set" + quote = "" + } + + if initializedVault != "" { + c.Ui.Output(fmt.Sprintf("Discovered an initialized Vault node at '%s'\n", initializedVault)) + c.Ui.Output("Set the following environment variable to operate on the discovered Vault:\n") + c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%shttp://%s%s", export, quote, initializedVault, quote)) + return 0 + } + + switch len(uninitializedVaults) { + case 0: + c.Ui.Error(fmt.Sprintf("Failed to discover Vault nodes under the service name '%s'", auto)) + return 1 + case 1: + // There was only one node found in the Vault cluster and it + // was uninitialized. + + // Set the VAULT_ADDR to the discovered node. This will ensure + // that the client created will operate on the discovered node. + os.Setenv(api.EnvVaultAddress, uninitializedVaults[0]) + + // Let the client know that initialization is perfomed on the + // discovered node. + c.Ui.Output(fmt.Sprintf("Discovered Vault at '%s'\n", uninitializedVaults[0])) + + // Attempt initializing it + ret := c.runInit(check, initRequest) + + // Regardless of success or failure, instruct client to update VAULT_ADDR + c.Ui.Output("Set the following environment variable to operate on the discovered Vault:\n") + c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%shttp://%s%s", export, quote, uninitializedVaults[0], quote)) + + return ret + default: + // If more than one Vault node were discovered, print out all of them, + // requiring the client to update VAULT_ADDR and to run init again. + c.Ui.Output(fmt.Sprintf("Discovered more than one uninitialized Vaults under the service name '%s'\n", auto)) + c.Ui.Output("To initialize all Vaults, set any *one* of the following and run 'vault init':") + + // Print valid commands to make setting the variables easier + for _, vaultNode := range uninitializedVaults { + c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%shttp://%s%s", export, quote, vaultNode, quote)) + + } + return 0 + } + } + + return c.runInit(check, initRequest) +} + +func (c *InitCommand) runInit(check bool, initRequest *api.InitRequest) int { client, err := c.Client() if err != nil { c.Ui.Error(fmt.Sprintf( @@ -43,15 +165,7 @@ func (c *InitCommand) Run(args []string) int { return c.checkStatus(client) } - resp, err := client.Sys().Init(&api.InitRequest{ - SecretShares: shares, - SecretThreshold: threshold, - StoredShares: storedShares, - PGPKeys: pgpKeys, - RecoveryShares: recoveryShares, - RecoveryThreshold: recoveryThreshold, - RecoveryPGPKeys: recoveryPgpKeys, - }) + resp, err := client.Sys().Init(initRequest) if err != nil { c.Ui.Error(fmt.Sprintf( "Error initializing Vault: %s", err)) @@ -67,7 +181,7 @@ func (c *InitCommand) Run(args []string) int { c.Ui.Output(fmt.Sprintf("Initial Root Token: %s", resp.RootToken)) - if storedShares < 1 { + if initRequest.StoredShares < 1 { c.Ui.Output(fmt.Sprintf( "\n"+ "Vault initialized with %d keys and a key threshold of %d. Please\n"+ @@ -76,10 +190,10 @@ func (c *InitCommand) Run(args []string) int { "to unseal it again.\n\n"+ "Vault does not store the master key. Without at least %d keys,\n"+ "your Vault will remain permanently sealed.", - shares, - threshold, - threshold, - threshold, + initRequest.SecretShares, + initRequest.SecretThreshold, + initRequest.SecretThreshold, + initRequest.SecretThreshold, )) } else { c.Ui.Output( @@ -92,8 +206,8 @@ func (c *InitCommand) Run(args []string) int { "\n"+ "Recovery key initialized with %d keys and a key threshold of %d. Please\n"+ "securely distribute the above keys.", - recoveryShares, - recoveryThreshold, + initRequest.RecoveryShares, + initRequest.RecoveryThreshold, )) }