From 574050b53f8ef36c975388f9d144080da7764bef Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 7 Apr 2015 22:30:25 -0700 Subject: [PATCH] helper/kv-builder --- command/write.go | 75 +++----------------- helper/kv-builder/builder.go | 114 ++++++++++++++++++++++++++++++ helper/kv-builder/builder_test.go | 86 ++++++++++++++++++++++ 3 files changed, 211 insertions(+), 64 deletions(-) create mode 100644 helper/kv-builder/builder.go create mode 100644 helper/kv-builder/builder_test.go diff --git a/command/write.go b/command/write.go index cfa5379a97..e2d72958e7 100644 --- a/command/write.go +++ b/command/write.go @@ -1,12 +1,12 @@ package command import ( - "encoding/json" "fmt" "io" - "io/ioutil" "os" "strings" + + "github.com/hashicorp/vault/helper/kv-builder" ) // WriteCommand is a Command that puts data into the Vault. @@ -61,70 +61,17 @@ func (c *WriteCommand) Run(args []string) int { } func (c *WriteCommand) parseData(args []string) (map[string]interface{}, error) { - result := make(map[string]interface{}) - - for i, arg := range args { - // If the arg is exactly "-" then we read from stdin and merge - // the resulting structure into the result. - if arg == "-" { - var stdin io.Reader = os.Stdin - if c.testStdin != nil { - stdin = c.testStdin - } - - dec := json.NewDecoder(stdin) - if err := dec.Decode(&result); err != nil { - return nil, fmt.Errorf( - "Error loading data at index %d: %s", i, err) - } - - continue - } - - // If the arg begins with "@" then we read the file directly. - if arg[0] == '@' { - f, err := os.Open(arg[1:]) - if err != nil { - return nil, fmt.Errorf( - "Error loading data at index %d: %s", i, err) - } - - dec := json.NewDecoder(f) - err = dec.Decode(&result) - f.Close() - if err != nil { - return nil, fmt.Errorf( - "Error loading data at index %d: %s", i, err) - } - - continue - } - - // Split into key/value - parts := strings.SplitN(arg, "=", 2) - if len(parts) != 2 { - return nil, fmt.Errorf( - "Data at index %d is not in key=value format: %s", - i, arg) - } - key, value := parts[0], parts[1] - - if value[0] == '@' { - contents, err := ioutil.ReadFile(value[1:]) - if err != nil { - return nil, fmt.Errorf( - "Error reading file value for index %d: %s", i, err) - } - - value = string(contents) - } else if value[0] == '\\' && value[1] == '@' { - value = value[1:] - } - - result[key] = value + var stdin io.Reader = os.Stdin + if c.testStdin != nil { + stdin = c.testStdin } - return result, nil + builder := &kvbuilder.Builder{Stdin: stdin} + if err := builder.Add(args...); err != nil { + return nil, err + } + + return builder.Map(), nil } func (c *WriteCommand) Synopsis() string { diff --git a/helper/kv-builder/builder.go b/helper/kv-builder/builder.go new file mode 100644 index 0000000000..767cff4b1c --- /dev/null +++ b/helper/kv-builder/builder.go @@ -0,0 +1,114 @@ +package kvbuilder + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "strings" +) + +// Builder is a struct to build a key/value mapping based on a list +// of "k=v" pairs, where the value might come from stdin, a file, etc. +type Builder struct { + Stdin io.Reader + + result map[string]interface{} + stdin bool +} + +// Map returns the built map. +func (b *Builder) Map() map[string]interface{} { + return b.result +} + +// Add adds to the mapping with the given args. +func (b *Builder) Add(args ...string) error { + for _, a := range args { + if err := b.add(a); err != nil { + return fmt.Errorf("Invalid key/value pair '%s': %s", a, err) + } + } + + return nil +} + +func (b *Builder) add(raw string) error { + // Regardless of validity, make sure we make our result + if b.result == nil { + b.result = make(map[string]interface{}) + } + + // Empty strings are fine, just ignored + if raw == "" { + return nil + } + + // If the arg is exactly "-", then we need to read from stdin + // and merge the results into the resulting structure. + if raw == "-" { + if b.Stdin == nil { + return fmt.Errorf("stdin is not supported") + } + if b.stdin { + return fmt.Errorf("stdin already consumed") + } + + b.stdin = true + return b.addReader(b.Stdin) + } + + // If the arg begins with "@" then we need to read a file directly + if raw[0] == '@' { + f, err := os.Open(raw[1:]) + if err != nil { + return err + } + defer f.Close() + + return b.addReader(f) + } + + // Split into key/value + parts := strings.SplitN(raw, "=", 2) + if len(parts) != 2 { + return fmt.Errorf("format must be key=value") + } + key, value := parts[0], parts[1] + + if len(value) > 0 && value[0] == '@' { + contents, err := ioutil.ReadFile(value[1:]) + if err != nil { + return fmt.Errorf("error reading file: %s", err) + } + + value = string(contents) + } else if value[0] == '\\' && value[1] == '@' { + value = value[1:] + } else if value == "-" { + if b.Stdin == nil { + return fmt.Errorf("stdin is not supported") + } + if b.stdin { + return fmt.Errorf("stdin already consumed") + } + b.stdin = true + + var buf bytes.Buffer + if _, err := io.Copy(&buf, b.Stdin); err != nil { + return err + } + + value = buf.String() + } + + b.result[key] = value + return nil +} + +func (b *Builder) addReader(r io.Reader) error { + dec := json.NewDecoder(r) + return dec.Decode(&b.result) +} diff --git a/helper/kv-builder/builder_test.go b/helper/kv-builder/builder_test.go new file mode 100644 index 0000000000..90e892969a --- /dev/null +++ b/helper/kv-builder/builder_test.go @@ -0,0 +1,86 @@ +package kvbuilder + +import ( + "bytes" + "reflect" + "testing" +) + +func TestBuilder_basic(t *testing.T) { + var b Builder + err := b.Add("foo=bar", "bar=baz") + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + actual := b.Map() + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("bad: %#v", actual) + } +} + +func TestBuilder_escapedAt(t *testing.T) { + var b Builder + err := b.Add("foo=bar", "bar=\\@baz") + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "@baz", + } + actual := b.Map() + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("bad: %#v", actual) + } +} + +func TestBuilder_stdin(t *testing.T) { + var b Builder + b.Stdin = bytes.NewBufferString("baz") + err := b.Add("foo=bar", "bar=-") + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + actual := b.Map() + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("bad: %#v", actual) + } +} + +func TestBuilder_stdinMap(t *testing.T) { + var b Builder + b.Stdin = bytes.NewBufferString(`{"foo": "bar"}`) + err := b.Add("-", "bar=baz") + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + actual := b.Map() + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("bad: %#v", actual) + } +} + +func TestBuilder_stdinTwice(t *testing.T) { + var b Builder + b.Stdin = bytes.NewBufferString(`{"foo": "bar"}`) + err := b.Add("-", "-") + if err == nil { + t.Fatal("should error") + } +}