From dc9d3f30124e9affb4a5c3b6be3240df12028d2f Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 22 Aug 2018 14:37:40 -0400 Subject: [PATCH] Sync some ns stuff to api/command --- api/client.go | 3 +- api/sys_generate_root.go | 2 + api/sys_namespaces.go | 97 ----------------- command/base.go | 31 +++++- command/base_flags.go | 24 ++--- command/namespace.go | 6 +- command/namespace_create.go | 11 +- command/namespace_delete.go | 18 +++- command/namespace_list.go | 30 ++++-- command/namespace_lookup.go | 22 ++-- command/operator_generate_root.go | 144 ++++++++++++++++--------- command/operator_generate_root_test.go | 23 +--- helper/consts/consts.go | 7 ++ helper/namespace/namespace.go | 52 ++++++--- 14 files changed, 237 insertions(+), 233 deletions(-) delete mode 100644 api/sys_namespaces.go diff --git a/api/client.go b/api/client.go index cc27444bcb..c7ced82373 100644 --- a/api/client.go +++ b/api/client.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/go-cleanhttp" retryablehttp "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/go-rootcerts" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/parseutil" "golang.org/x/net/http2" "golang.org/x/time/rate" @@ -474,7 +475,7 @@ func (c *Client) SetNamespace(namespace string) { c.headers = make(http.Header) } - c.headers.Set("X-Vault-Namespace", namespace) + c.headers.Set(consts.NamespaceHeaderName, namespace) } // Token returns the access token being used by this client. It will diff --git a/api/sys_generate_root.go b/api/sys_generate_root.go index ec7314da10..66f72dff69 100644 --- a/api/sys_generate_root.go +++ b/api/sys_generate_root.go @@ -119,4 +119,6 @@ type GenerateRootStatusResponse struct { EncodedToken string `json:"encoded_token"` EncodedRootToken string `json:"encoded_root_token"` PGPFingerprint string `json:"pgp_fingerprint"` + OTP string `json:"otp"` + OTPLength int `json:"otp_length"` } diff --git a/api/sys_namespaces.go b/api/sys_namespaces.go deleted file mode 100644 index 5d86a5514f..0000000000 --- a/api/sys_namespaces.go +++ /dev/null @@ -1,97 +0,0 @@ -package api - -import ( - "fmt" - "net/http" -) - -// ListNamespacesResponse is the response from the ListNamespaces call. -type ListNamespacesResponse struct { - // NamespacePaths is the list of child namespace paths - NamespacePaths []string `json:"namespace_paths"` -} - -type GetNamespaceResponse struct { - Path string `json:"path"` -} - -// ListNamespaces lists any existing namespace relative to the namespace -// provided in the client's namespace header. -func (c *Sys) ListNamespaces() (*ListNamespacesResponse, error) { - r := c.c.NewRequest("LIST", "/v1/sys/namespaces") - - resp, err := c.c.RawRequest(r) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var result struct { - Data struct { - Keys []string `json:"keys"` - } `json:"data"` - } - err = resp.DecodeJSON(&result) - if err != nil { - return nil, err - } - - return &ListNamespacesResponse{NamespacePaths: result.Data.Keys}, nil -} - -// GetNamespace returns namespace information -func (c *Sys) GetNamespace(path string) (*GetNamespaceResponse, error) { - r := c.c.NewRequest("GET", fmt.Sprintf("/v1/sys/namespaces/%s", path)) - resp, err := c.c.RawRequest(r) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusNotFound { - return nil, nil - } - - ret := &GetNamespaceResponse{} - result := map[string]interface{}{ - "data": map[string]interface{}{}, - } - if err := resp.DecodeJSON(&result); err != nil { - return nil, err - } - - if data, ok := result["data"]; ok { - if pathOk, ok := data.(map[string]interface{})["path"]; ok { - if pathRaw, ok := pathOk.(string); ok { - ret.Path = pathRaw - } - } - } - - return ret, nil -} - -// CreateNamespace creates a new namespace relative to the namespace provided -// in the client's namespace header. -func (c *Sys) CreateNamespace(path string) error { - r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/namespaces/%s", path)) - resp, err := c.c.RawRequest(r) - if err != nil { - return err - } - defer resp.Body.Close() - - return nil -} - -// DeleteNamespace delete an existing namespace relative to the namespace -// provided in the client's namespace header. -func (c *Sys) DeleteNamespace(path string) error { - r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/namespaces/%s", path)) - resp, err := c.c.RawRequest(r) - if err != nil { - return err - } - defer resp.Body.Close() - - return nil -} diff --git a/command/base.go b/command/base.go index 99c77b12da..2f98ab9491 100644 --- a/command/base.go +++ b/command/base.go @@ -20,8 +20,13 @@ import ( "github.com/posener/complete" ) -// maxLineLength is the maximum width of any line. -const maxLineLength int = 78 +const ( + // maxLineLength is the maximum width of any line. + maxLineLength int = 78 + + // notSetNamespace is a flag value for a not-set namespace + notSetNamespace = "(not set)" +) // reRemoveWhitespace is a regular expression for stripping whitespace from // a string. @@ -39,6 +44,7 @@ type BaseCommand struct { flagClientCert string flagClientKey string flagNamespace string + flagNS string flagTLSServerName string flagTLSSkipVerify bool flagWrapTTL time.Duration @@ -120,7 +126,12 @@ func (c *BaseCommand) Client() (*api.Client, error) { } client.SetMFACreds(c.flagMFA) - client.SetNamespace(namespace.Canonicalize(c.flagNamespace)) + switch { + case c.flagNS != notSetNamespace: + client.SetNamespace(namespace.Canonicalize(c.flagNS)) + case c.flagNamespace != notSetNamespace: + client.SetNamespace(namespace.Canonicalize(c.flagNamespace)) + } c.client = client @@ -242,11 +253,21 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { f.StringVar(&StringVar{ Name: "namespace", Target: &c.flagNamespace, - Default: "", + Default: notSetNamespace, // this can never be a real value EnvVar: "VAULT_NAMESPACE", Completion: complete.PredictAnything, Usage: "The namespace to use for the command. Setting this is not " + - "necessary but allows using relative paths.", + "necessary but allows using relative paths. -ns can be used as " + + "shortcut.", + }) + + f.StringVar(&StringVar{ + Name: "ns", + Target: &c.flagNS, + Default: notSetNamespace, // this can never be a real value + Completion: complete.PredictAnything, + Hidden: true, + Usage: "Alias for -namespace.", }) f.StringVar(&StringVar{ diff --git a/command/base_flags.go b/command/base_flags.go index cc5bfbd318..4723170370 100644 --- a/command/base_flags.go +++ b/command/base_flags.go @@ -43,7 +43,7 @@ type BoolVar struct { func (f *FlagSet) BoolVar(i *BoolVar) { def := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { if b, err := strconv.ParseBool(v); err == nil { def = b } @@ -104,7 +104,7 @@ type IntVar struct { func (f *FlagSet) IntVar(i *IntVar) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { if i, err := strconv.ParseInt(v, 0, 64); err == nil { initial = int(i) } @@ -168,7 +168,7 @@ type Int64Var struct { func (f *FlagSet) Int64Var(i *Int64Var) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { if i, err := strconv.ParseInt(v, 0, 64); err == nil { initial = i } @@ -232,7 +232,7 @@ type UintVar struct { func (f *FlagSet) UintVar(i *UintVar) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { if i, err := strconv.ParseUint(v, 0, 64); err == nil { initial = uint(i) } @@ -296,7 +296,7 @@ type Uint64Var struct { func (f *FlagSet) Uint64Var(i *Uint64Var) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { if i, err := strconv.ParseUint(v, 0, 64); err == nil { initial = i } @@ -360,7 +360,7 @@ type StringVar struct { func (f *FlagSet) StringVar(i *StringVar) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { initial = v } @@ -417,7 +417,7 @@ type Float64Var struct { func (f *FlagSet) Float64Var(i *Float64Var) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { if i, err := strconv.ParseFloat(v, 64); err == nil { initial = i } @@ -481,7 +481,7 @@ type DurationVar struct { func (f *FlagSet) DurationVar(i *DurationVar) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { if d, err := time.ParseDuration(appendDurationSuffix(v)); err == nil { initial = d } @@ -558,7 +558,7 @@ type StringSliceVar struct { func (f *FlagSet) StringSliceVar(i *StringSliceVar) { initial := i.Default - if v := os.Getenv(i.EnvVar); v != "" { + if v, exist := os.LookupEnv(i.EnvVar); exist { parts := strings.Split(v, ",") for i := range parts { parts[i] = strings.TrimSpace(parts[i]) @@ -751,14 +751,14 @@ func (f *FlagSet) Var(value flag.Value, name, usage string) { // -- helpers func envDefault(key, def string) string { - if v := os.Getenv(key); v != "" { + if v, exist := os.LookupEnv(key); exist { return v } return def } func envBoolDefault(key string, def bool) bool { - if v := os.Getenv(key); v != "" { + if v, exist := os.LookupEnv(key); exist { b, err := strconv.ParseBool(v) if err != nil { panic(err) @@ -769,7 +769,7 @@ func envBoolDefault(key string, def bool) bool { } func envDurationDefault(key string, def time.Duration) time.Duration { - if v := os.Getenv(key); v != "" { + if v, exist := os.LookupEnv(key); exist { d, err := time.ParseDuration(v) if err != nil { panic(err) diff --git a/command/namespace.go b/command/namespace.go index 1f232ef157..8adfc66ef2 100644 --- a/command/namespace.go +++ b/command/namespace.go @@ -21,9 +21,9 @@ func (c *NamespaceCommand) Help() string { Usage: vault namespace [options] [args] This command groups subcommands for interacting with Vault namespaces. - These set of subcommands operate on the context of the namespace that the - current logged in token belongs to. - + These subcommands operate in the context of the namespace that the + currently logged in token belongs to. + List enabled child namespaces: $ vault namespace list diff --git a/command/namespace_create.go b/command/namespace_create.go index 3097215c5b..17e2d34e07 100644 --- a/command/namespace_create.go +++ b/command/namespace_create.go @@ -79,14 +79,19 @@ func (c *NamespaceCreateCommand) Run(args []string) int { return 2 } - err = client.Sys().CreateNamespace(namespacePath) + _, err = client.Logical().Write("sys/namespaces/"+namespacePath, nil) if err != nil { c.UI.Error(fmt.Sprintf("Error creating namespace: %s", err)) return 2 } + if !strings.HasSuffix(namespacePath, "/") { + namespacePath = namespacePath + "/" + } + if c.flagNamespace != notSetNamespace { + namespacePath = path.Join(c.flagNamespace, namespacePath) + } // Output full path - fullPath := path.Join(c.flagNamespace, namespacePath) + "/" - c.UI.Output(fmt.Sprintf("Success! Namespace created at: %s", fullPath)) + c.UI.Output(fmt.Sprintf("Success! Namespace created at: %s", namespacePath)) return 0 } diff --git a/command/namespace_delete.go b/command/namespace_delete.go index 3565d1b816..416f0f31bc 100644 --- a/command/namespace_delete.go +++ b/command/namespace_delete.go @@ -79,14 +79,24 @@ func (c *NamespaceDeleteCommand) Run(args []string) int { return 2 } - err = client.Sys().DeleteNamespace(namespacePath) + secret, err := client.Logical().Delete("sys/namespaces/" + namespacePath) if err != nil { c.UI.Error(fmt.Sprintf("Error deleting namespace: %s", err)) return 2 } - // Output full path - fullPath := path.Join(c.flagNamespace, namespacePath) + "/" - c.UI.Output(fmt.Sprintf("Success! Namespace deleted at: %s", fullPath)) + if secret != nil { + // Likely, we have warnings + return OutputSecret(c.UI, secret) + } + + if !strings.HasSuffix(namespacePath, "/") { + namespacePath = namespacePath + "/" + } + if c.flagNamespace != notSetNamespace { + namespacePath = path.Join(c.flagNamespace, namespacePath) + } + + c.UI.Output(fmt.Sprintf("Success! Namespace deleted at: %s", namespacePath)) return 0 } diff --git a/command/namespace_list.go b/command/namespace_list.go index 5ba3248faf..893e1a76e4 100644 --- a/command/namespace_list.go +++ b/command/namespace_list.go @@ -66,19 +66,29 @@ func (c *NamespaceListCommand) Run(args []string) int { return 2 } - namespaces, err := client.Sys().ListNamespaces() + secret, err := client.Logical().List("sys/namespaces") if err != nil { c.UI.Error(fmt.Sprintf("Error listing namespaces: %s", err)) return 2 } - - switch Format(c.UI) { - case "table": - for _, ns := range namespaces.NamespacePaths { - c.UI.Output(ns) - } - return 0 - default: - return OutputData(c.UI, namespaces) + if secret == nil { + c.UI.Error(fmt.Sprintf("No namespaces found")) + return 2 } + + // There could be e.g. warnings + if secret.Data == nil { + return OutputSecret(c.UI, secret) + } + + if secret.WrapInfo != nil && secret.WrapInfo.TTL != 0 { + return OutputSecret(c.UI, secret) + } + + if _, ok := extractListData(secret); !ok { + c.UI.Error(fmt.Sprintf("No entries found")) + return 2 + } + + return OutputList(c.UI, secret) } diff --git a/command/namespace_lookup.go b/command/namespace_lookup.go index 0e913becfa..c025a4e8fd 100644 --- a/command/namespace_lookup.go +++ b/command/namespace_lookup.go @@ -16,14 +16,14 @@ type NamespaceLookupCommand struct { } func (c *NamespaceLookupCommand) Synopsis() string { - return "Create a new namespace" + return "Look up an existing namespace" } func (c *NamespaceLookupCommand) Help() string { helpText := ` Usage: vault namespace create [options] PATH - Create a child namespace. The namespace created will be relative to the + Create a child namespace. The namespace created will be relative to the namespace provided in either VAULT_NAMESPACE environemnt variable or -namespace CLI flag. @@ -33,7 +33,7 @@ Usage: vault namespace create [options] PATH Get information about the namespace of a particular child token (e.g. ns1/ns2/): - $ vault namespace create -namespace=ns1 ns2 + $ vault namespace lookup -namespace=ns1 ns2 ` + c.Flags().Help() @@ -78,19 +78,15 @@ func (c *NamespaceLookupCommand) Run(args []string) int { return 2 } - resp, err := client.Sys().GetNamespace(namespacePath) + secret, err := client.Logical().Read("sys/namespaces/" + namespacePath) if err != nil { c.UI.Error(fmt.Sprintf("Error looking up namespace: %s", err)) return 2 } - - switch Format(c.UI) { - case "table": - data := map[string]interface{}{ - "path": resp.Path, - } - return OutputData(c.UI, data) - default: - return OutputData(c.UI, resp) + if secret == nil { + c.UI.Error("Namespace not found") + return 2 } + + return OutputSecret(c.UI, secret) } diff --git a/command/operator_generate_root.go b/command/operator_generate_root.go index 1eed1bf22a..a8f098c0a7 100644 --- a/command/operator_generate_root.go +++ b/command/operator_generate_root.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/base62" "github.com/hashicorp/vault/helper/password" "github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/helper/xor" @@ -123,8 +124,7 @@ func (c *OperatorGenerateRootCommand) Flags() *FlagSets { Default: "", EnvVar: "", Completion: complete.PredictAnything, - Usage: "Decode and output the generated root token. This option requires " + - "the \"-otp\" flag be set to the OTP used during initialization.", + Usage: "The value to decode; setting this triggers a decode operation.", }) f.BoolVar(&BoolVar{ @@ -233,9 +233,13 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int { switch { case c.flagGenerateOTP: - return c.generateOTP() + otp, code := c.generateOTP(client, c.flagDRToken) + if code == 0 { + return PrintRaw(c.UI, otp) + } + return code case c.flagDecode != "": - return c.decode(c.flagDecode, c.flagOTP) + return c.decode(client, c.flagDecode, c.flagOTP, c.flagDRToken) case c.flagCancel: return c.cancel(client, c.flagDRToken) case c.flagInit: @@ -252,41 +256,48 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int { } } -// verifyOTP verifies the given OTP code is exactly 16 bytes. -func (c *OperatorGenerateRootCommand) verifyOTP(otp string) error { - if len(otp) == 0 { - return fmt.Errorf("no OTP passed in") - } - otpBytes, err := base64.StdEncoding.DecodeString(otp) - if err != nil { - return errwrap.Wrapf("error decoding base64 OTP value: {{err}}", err) - } - if otpBytes == nil || len(otpBytes) != 16 { - return fmt.Errorf("decoded OTP value is invalid or wrong length") - } - - return nil -} - // generateOTP generates a suitable OTP code for generating a root token. -func (c *OperatorGenerateRootCommand) generateOTP() int { - buf := make([]byte, 16) - readLen, err := rand.Read(buf) +func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bool) (string, int) { + f := client.Sys().GenerateRootStatus + if drToken { + f = client.Sys().GenerateDROperationTokenStatus + } + status, err := f() if err != nil { - c.UI.Error(fmt.Sprintf("Error reading random bytes: %s", err)) - return 2 + c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err)) + return "", 2 } - if readLen != 16 { - c.UI.Error(fmt.Sprintf("Read %d bytes when we should have read 16", readLen)) - return 2 - } + switch status.OTPLength { + case 0: + // This is the fallback case + buf := make([]byte, 16) + readLen, err := rand.Read(buf) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading random bytes: %s", err)) + return "", 2 + } - return PrintRaw(c.UI, base64.StdEncoding.EncodeToString(buf)) + if readLen != 16 { + c.UI.Error(fmt.Sprintf("Read %d bytes when we should have read 16", readLen)) + return "", 2 + } + + return base64.StdEncoding.EncodeToString(buf), 0 + + default: + otp, err := base62.Random(status.OTPLength, true) + if err != nil { + c.UI.Error(errwrap.Wrapf("Error reading random bytes: {{err}}", err).Error()) + return "", 2 + } + + return otp, 0 + } } // decode decodes the given value using the otp. -func (c *OperatorGenerateRootCommand) decode(encoded, otp string) int { +func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, drToken bool) int { if encoded == "" { c.UI.Error("Missing encoded value: use -decode= to supply it") return 1 @@ -296,38 +307,56 @@ func (c *OperatorGenerateRootCommand) decode(encoded, otp string) int { return 1 } - tokenBytes, err := xor.XORBase64(encoded, otp) + f := client.Sys().GenerateRootStatus + if drToken { + f = client.Sys().GenerateDROperationTokenStatus + } + status, err := f() if err != nil { - c.UI.Error(fmt.Sprintf("Error xoring token: %s", err)) - return 1 + c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err)) + return 2 } - token, err := uuid.FormatUUID(tokenBytes) - if err != nil { - c.UI.Error(fmt.Sprintf("Error formatting base64 token value: %s", err)) - return 1 - } + switch status.OTPLength { + case 0: + // Backwards compat + tokenBytes, err := xor.XORBase64(encoded, otp) + if err != nil { + c.UI.Error(fmt.Sprintf("Error xoring token: %s", err)) + return 1 + } - return PrintRaw(c.UI, strings.TrimSpace(token)) + token, err := uuid.FormatUUID(tokenBytes) + if err != nil { + c.UI.Error(fmt.Sprintf("Error formatting base64 token value: %s", err)) + return 1 + } + + return PrintRaw(c.UI, strings.TrimSpace(token)) + + default: + tokenBytes, err := base64.RawStdEncoding.DecodeString(encoded) + if err != nil { + c.UI.Error(errwrap.Wrapf("Error decoding base64'd token: {{err}}", err).Error()) + return 1 + } + + tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp)) + if err != nil { + c.UI.Error(errwrap.Wrapf("Error xoring token: {{err}}", err).Error()) + return 1 + } + + return PrintRaw(c.UI, string(tokenBytes)) + } } // init is used to start the generation process func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, drToken bool) int { // Validate incoming fields. Either OTP OR PGP keys must be supplied. - switch { - case otp == "" && pgpKey == "": - c.UI.Error("Error initializing: must specify either -otp or -pgp-key") - return 1 - case otp != "" && pgpKey != "": + if otp != "" && pgpKey != "" { c.UI.Error("Error initializing: cannot specify both -otp and -pgp-key") return 1 - case otp != "": - if err := c.verifyOTP(otp); err != nil { - c.UI.Error(fmt.Sprintf("Error initializing: invalid OTP: %s", err)) - return 1 - } - case pgpKey != "": - // OK } // Start the root generation @@ -368,6 +397,10 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr c.UI.Error(wrapAtLength( "No root generation is in progress. Start a root generation by " + "running \"vault operator generate-root -init\".")) + c.UI.Warn(wrapAtLength(fmt.Sprintf( + "If starting root generation using the OTP method and generating "+ + "your own OTP, the length of the OTP string needs to be %d "+ + "characters in length.", status.OTPLength))) return 1 } @@ -494,6 +527,13 @@ func (c *OperatorGenerateRootCommand) printStatus(status *api.GenerateRootStatus case status.EncodedRootToken != "": out = append(out, fmt.Sprintf("Encoded Root Token | %s", status.EncodedRootToken)) } + if status.OTP != "" { + c.UI.Warn(wrapAtLength("A One-Time-Password has been generated for you and is shown in the OTP field. You will need this value to decode the resulting root token, so keep it safe.")) + out = append(out, fmt.Sprintf("OTP | %s", status.OTP)) + } + if status.OTPLength != 0 { + out = append(out, fmt.Sprintf("OTP Length | %d", status.OTPLength)) + } output := columnOutput(out, nil) c.UI.Output(output) diff --git a/command/operator_generate_root_test.go b/command/operator_generate_root_test.go index ecbb6c16cd..f8601183ab 100644 --- a/command/operator_generate_root_test.go +++ b/command/operator_generate_root_test.go @@ -34,22 +34,14 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) { out string code int }{ - { - "init_no_args", - []string{ - "-init", - }, - "must specify either -otp or -pgp-key", - 1, - }, { "init_invalid_otp", []string{ "-init", "-otp", "not-a-valid-otp", }, - "Error initializing: invalid OTP:", - 1, + "illegal base64 data at input", + 2, }, { "init_pgp_multi", @@ -99,12 +91,12 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) { code := cmd.Run(tc.args) if code != tc.code { - t.Errorf("expected %d to be %d", code, tc.code) + t.Errorf("%s: expected %d to be %d", tc.name, code, tc.code) } combined := ui.OutputWriter.String() + ui.ErrorWriter.String() if !strings.Contains(combined, tc.out) { - t.Errorf("expected %q to contain %q", combined, tc.out) + t.Errorf("%s: expected %q to contain %q", tc.name, combined, tc.out) } }) } @@ -116,7 +108,7 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) { client, closer := testVaultServer(t) defer closer() - ui, cmd := testOperatorGenerateRootCommand(t) + _, cmd := testOperatorGenerateRootCommand(t) cmd.client = client code := cmd.Run([]string{ @@ -125,11 +117,6 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) { if exp := 0; code != exp { t.Errorf("expected %d to be %d", code, exp) } - - output := ui.OutputWriter.String() + ui.ErrorWriter.String() - if err := cmd.verifyOTP(output); err != nil { - t.Fatal(err) - } }) t.Run("decode", func(t *testing.T) { diff --git a/helper/consts/consts.go b/helper/consts/consts.go index eee59d9c99..972a69f47b 100644 --- a/helper/consts/consts.go +++ b/helper/consts/consts.go @@ -4,4 +4,11 @@ const ( // ExpirationRestoreWorkerCount specifies the number of workers to use while // restoring leases into the expiration manager ExpirationRestoreWorkerCount = 64 + + // NamespaceHeaderName is the header set to specify which namespace the + // request is indented for. + NamespaceHeaderName = "X-Vault-Namespace" + + // AuthHeaderName is the name of the header containing the token. + AuthHeaderName = "X-Vault-Token" ) diff --git a/helper/namespace/namespace.go b/helper/namespace/namespace.go index 1c57040a97..a7855cf7be 100644 --- a/helper/namespace/namespace.go +++ b/helper/namespace/namespace.go @@ -3,6 +3,7 @@ package namespace import ( "context" "errors" + "net/http" "strings" ) @@ -16,6 +17,11 @@ type nsContext struct { type contextValues struct{} +type Namespace struct { + ID string `json:"id"` + Path string `json:"path"` +} + const ( RootNamespaceID = "root" ) @@ -23,18 +29,14 @@ const ( var ( contextNamespace contextValues = struct{}{} ErrNoNamespace error = errors.New("no namespace") + RootNamespace *Namespace = &Namespace{ + ID: RootNamespaceID, + Path: "", + } ) -type Namespace struct { - ID string `json:"id"` - Path string `json:"path"` -} - -func New(id, path string) *Namespace { - return &Namespace{ - ID: id, - Path: path, - } +var AdjustRequest = func(r *http.Request) (*http.Request, int) { + return r.WithContext(ContextWithNamespace(r.Context(), RootNamespace)), 0 } func (n *Namespace) HasParent(possibleParent *Namespace) bool { @@ -60,6 +62,18 @@ func ContextWithNamespace(ctx context.Context, ns *Namespace) context.Context { } } +func RootContext(ctx context.Context) context.Context { + if ctx == nil { + return ContextWithNamespace(context.Background(), RootNamespace) + } + return ContextWithNamespace(ctx, RootNamespace) +} + +// This function caches the ns to avoid doing a .Value lookup over and over, +// because it's called a *lot* in the request critical path. .Value is +// concurrency-safe so uses some kind of locking/atomicity, but it should never +// be read before first write, plus we don't believe this will be called from +// different goroutines, so it should be safe. func FromContext(ctx context.Context) (*Namespace, error) { if ctx == nil { return nil, errors.New("context was nil") @@ -72,20 +86,28 @@ func FromContext(ctx context.Context) (*Namespace, error) { } } - ns := ctx.Value(contextNamespace) + nsRaw := ctx.Value(contextNamespace) + if nsRaw == nil { + return nil, ErrNoNamespace + } + + ns := nsRaw.(*Namespace) if ns == nil { return nil, ErrNoNamespace } if ok { - nsCtx.cachedNS = ns.(*Namespace) + nsCtx.cachedNS = ns } - - return ns.(*Namespace), nil + return ns, nil } func TestContext() context.Context { - return ContextWithNamespace(context.Background(), New(RootNamespaceID, "")) + return ContextWithNamespace(context.Background(), TestNamespace()) +} + +func TestNamespace() *Namespace { + return RootNamespace } // Canonicalize trims any prefix '/' and adds a trailing '/' to the