mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 10:37:56 +00:00 
			
		
		
		
	Expand and centralize helpers
This commit is contained in:
		| @@ -2,35 +2,72 @@ package command | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/api" | ||||
| 	kvbuilder "github.com/hashicorp/vault/helper/kv-builder" | ||||
| 	homedir "github.com/mitchellh/go-homedir" | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/ryanuber/columnize" | ||||
| ) | ||||
|  | ||||
| var ErrMissingID = fmt.Errorf("Missing ID!") | ||||
| var ErrMissingPath = fmt.Errorf("Missing PATH!") | ||||
| var ErrMissingThing = fmt.Errorf("Missing THING!") | ||||
|  | ||||
| // extractListData reads the secret and returns a typed list of data and a | ||||
| // boolean indicating whether the extraction was successful. | ||||
| func extractListData(secret *api.Secret) ([]interface{}, bool) { | ||||
| 	if secret == nil || secret.Data == nil { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	k, ok := secret.Data["keys"] | ||||
| 	if !ok || k == nil { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	i, ok := k.([]interface{}) | ||||
| 	return i, ok | ||||
| } | ||||
|  | ||||
| // extractPath extracts the path and list of arguments from the args. If there | ||||
| // are no extra arguments, the remaining args will be nil. | ||||
| func extractPath(args []string) (string, []string, error) { | ||||
| 	str, remaining, err := extractThings(args) | ||||
| 	if err == ErrMissingThing { | ||||
| 		err = ErrMissingPath | ||||
| 	} | ||||
| 	return str, remaining, err | ||||
| } | ||||
|  | ||||
| // extractID extracts the path and list of arguments from the args. If there | ||||
| // are no extra arguments, the remaining args will be nil. | ||||
| func extractID(args []string) (string, []string, error) { | ||||
| 	str, remaining, err := extractThings(args) | ||||
| 	if err == ErrMissingThing { | ||||
| 		err = ErrMissingID | ||||
| 	} | ||||
| 	return str, remaining, err | ||||
| } | ||||
|  | ||||
| func extractThings(args []string) (string, []string, error) { | ||||
| 	if len(args) < 1 { | ||||
| 		return "", nil, ErrMissingPath | ||||
| 		return "", nil, ErrMissingThing | ||||
| 	} | ||||
|  | ||||
| 	// Path is always the first argument after all flags | ||||
| 	path := args[0] | ||||
| 	thing := args[0] | ||||
|  | ||||
| 	// Strip leading and trailing slashes | ||||
| 	for len(path) > 0 && path[0] == '/' { | ||||
| 		path = path[1:] | ||||
| 	} | ||||
| 	for len(path) > 0 && path[len(path)-1] == '/' { | ||||
| 		path = path[:len(path)-1] | ||||
| 	} | ||||
| 	thing = sanitizePath(thing) | ||||
|  | ||||
| 	// Trim any leading/trailing whitespace | ||||
| 	path = strings.TrimSpace(path) | ||||
|  | ||||
| 	// Verify we have a path | ||||
| 	if path == "" { | ||||
| 		return "", nil, ErrMissingPath | ||||
| 	// Verify we have a thing | ||||
| 	if thing == "" { | ||||
| 		return "", nil, ErrMissingThing | ||||
| 	} | ||||
|  | ||||
| 	// Splice remaining args | ||||
| @@ -39,5 +76,150 @@ func extractPath(args []string) (string, []string, error) { | ||||
| 		remaining = args[1:] | ||||
| 	} | ||||
|  | ||||
| 	return path, remaining, nil | ||||
| 	return thing, remaining, nil | ||||
| } | ||||
|  | ||||
| // sanitizePath removes any leading or trailing things from a "path". | ||||
| func sanitizePath(s string) string { | ||||
| 	return ensureNoTrailingSlash(ensureNoLeadingSlash(s)) | ||||
| } | ||||
|  | ||||
| // ensureTrailingSlash ensures the given string has a trailing slash. | ||||
| func ensureTrailingSlash(s string) string { | ||||
| 	s = strings.TrimSpace(s) | ||||
| 	if s == "" { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	for len(s) > 0 && s[len(s)-1] != '/' { | ||||
| 		s = s + "/" | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // ensureNoTrailingSlash ensures the given string has a trailing slash. | ||||
| func ensureNoTrailingSlash(s string) string { | ||||
| 	s = strings.TrimSpace(s) | ||||
| 	if s == "" { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	for len(s) > 0 && s[len(s)-1] == '/' { | ||||
| 		s = s[:len(s)-1] | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // ensureNoLeadingSlash ensures the given string has a trailing slash. | ||||
| func ensureNoLeadingSlash(s string) string { | ||||
| 	s = strings.TrimSpace(s) | ||||
| 	if s == "" { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	for len(s) > 0 && s[0] == '/' { | ||||
| 		s = s[1:] | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // columnOuput prints the list of items as a table with no headers. | ||||
| func columnOutput(list []string) string { | ||||
| 	if len(list) == 0 { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	return columnize.Format(list, &columnize.Config{ | ||||
| 		Glue:  "    ", | ||||
| 		Empty: "n/a", | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // tableOutput prints the list of items as columns, where the first row is | ||||
| // the list of headers. | ||||
| func tableOutput(list []string) string { | ||||
| 	if len(list) == 0 { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	underline := "" | ||||
| 	headers := strings.Split(list[0], "|") | ||||
| 	for i, h := range headers { | ||||
| 		h = strings.TrimSpace(h) | ||||
| 		u := strings.Repeat("-", len(h)) | ||||
|  | ||||
| 		underline = underline + u | ||||
| 		if i != len(headers)-1 { | ||||
| 			underline = underline + " | " | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	list = append(list, "") | ||||
| 	copy(list[2:], list[1:]) | ||||
| 	list[1] = underline | ||||
|  | ||||
| 	return columnOutput(list) | ||||
| } | ||||
|  | ||||
| // parseArgsData parses the given args in the format key=value into a map of | ||||
| // the provided arguments. The given reader can also supply key=value pairs. | ||||
| func parseArgsData(stdin io.Reader, args []string) (map[string]interface{}, error) { | ||||
| 	builder := &kvbuilder.Builder{Stdin: stdin} | ||||
| 	if err := builder.Add(args...); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return builder.Map(), nil | ||||
| } | ||||
|  | ||||
| // parseArgsDataString parses the args data and returns the values as strings. | ||||
| // If the values cannot be represented as strings, an error is returned. | ||||
| func parseArgsDataString(stdin io.Reader, args []string) (map[string]string, error) { | ||||
| 	raw, err := parseArgsData(stdin, args) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var result map[string]string | ||||
| 	if err := mapstructure.WeakDecode(raw, &result); err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to convert values to strings") | ||||
| 	} | ||||
| 	return result, nil | ||||
| } | ||||
|  | ||||
| // truncateToSeconds truncates the given duaration to the number of seconds. If | ||||
| // the duration is less than 1s, it is returned as 0. The integer represents | ||||
| // the whole number unit of seconds for the duration. | ||||
| func truncateToSeconds(d time.Duration) int { | ||||
| 	d = d.Truncate(1 * time.Second) | ||||
|  | ||||
| 	// Handle the case where someone requested a ridiculously short increment - | ||||
| 	// incremenents must be larger than a second. | ||||
| 	if d < 1*time.Second { | ||||
| 		return 0 | ||||
| 	} | ||||
|  | ||||
| 	return int(d.Seconds()) | ||||
| } | ||||
|  | ||||
| // printKeyStatus prints the KeyStatus response from the API. | ||||
| func printKeyStatus(ks *api.KeyStatus) string { | ||||
| 	return columnOutput([]string{ | ||||
| 		fmt.Sprintf("Key Term | %d", ks.Term), | ||||
| 		fmt.Sprintf("Install Time | %s", ks.InstallTime.UTC().Format(time.RFC822)), | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // expandPath takes a filepath and returns the full expanded path, accounting | ||||
| // for user-relative things like ~/. | ||||
| func expandPath(s string) string { | ||||
| 	if s == "" { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	e, err := homedir.Expand(s) | ||||
| 	if err != nil { | ||||
| 		return s | ||||
| 	} | ||||
| 	return e | ||||
| } | ||||
|   | ||||
							
								
								
									
										162
									
								
								command/base_helpers_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								command/base_helpers_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,162 @@ | ||||
| package command | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestParseArgsData(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	t.Run("stdin_full", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		stdinR, stdinW := io.Pipe() | ||||
| 		go func() { | ||||
| 			stdinW.Write([]byte(`{"foo":"bar"}`)) | ||||
| 			stdinW.Close() | ||||
| 		}() | ||||
|  | ||||
| 		m, err := parseArgsData(stdinR, []string{"-"}) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		if v, ok := m["foo"]; !ok || v != "bar" { | ||||
| 			t.Errorf("expected %q to be %q", v, "bar") | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("stdin_value", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		stdinR, stdinW := io.Pipe() | ||||
| 		go func() { | ||||
| 			stdinW.Write([]byte(`bar`)) | ||||
| 			stdinW.Close() | ||||
| 		}() | ||||
|  | ||||
| 		m, err := parseArgsData(stdinR, []string{"foo=-"}) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		if v, ok := m["foo"]; !ok || v != "bar" { | ||||
| 			t.Errorf("expected %q to be %q", v, "bar") | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("file_full", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		f, err := ioutil.TempFile("", "vault") | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		f.Write([]byte(`{"foo":"bar"}`)) | ||||
| 		f.Close() | ||||
| 		defer os.Remove(f.Name()) | ||||
|  | ||||
| 		m, err := parseArgsData(os.Stdin, []string{"@" + f.Name()}) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		if v, ok := m["foo"]; !ok || v != "bar" { | ||||
| 			t.Errorf("expected %q to be %q", v, "bar") | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("file_value", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		f, err := ioutil.TempFile("", "vault") | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		f.Write([]byte(`bar`)) | ||||
| 		f.Close() | ||||
| 		defer os.Remove(f.Name()) | ||||
|  | ||||
| 		m, err := parseArgsData(os.Stdin, []string{"foo=@" + f.Name()}) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		if v, ok := m["foo"]; !ok || v != "bar" { | ||||
| 			t.Errorf("expected %q to be %q", v, "bar") | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("file_value_escaped", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		m, err := parseArgsData(os.Stdin, []string{`foo=\@`}) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		if v, ok := m["foo"]; !ok || v != "@" { | ||||
| 			t.Errorf("expected %q to be %q", v, "@") | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestTruncateToSeconds(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	cases := []struct { | ||||
| 		d   time.Duration | ||||
| 		exp int | ||||
| 	}{ | ||||
| 		{ | ||||
| 			10 * time.Nanosecond, | ||||
| 			0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			10 * time.Microsecond, | ||||
| 			0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			10 * time.Millisecond, | ||||
| 			0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			1 * time.Second, | ||||
| 			1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			10 * time.Second, | ||||
| 			10, | ||||
| 		}, | ||||
| 		{ | ||||
| 			100 * time.Second, | ||||
| 			100, | ||||
| 		}, | ||||
| 		{ | ||||
| 			3 * time.Minute, | ||||
| 			180, | ||||
| 		}, | ||||
| 		{ | ||||
| 			3 * time.Hour, | ||||
| 			10800, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		tc := tc | ||||
|  | ||||
| 		t.Run(fmt.Sprintf("%s", tc.d), func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			act := truncateToSeconds(tc.d) | ||||
| 			if act != tc.exp { | ||||
| 				t.Errorf("expected %d to be %d", act, tc.exp) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Seth Vargo
					Seth Vargo