diff --git a/command/generate-root.go b/command/generate-root.go index f9d3222427..2d85424f1c 100644 --- a/command/generate-root.go +++ b/command/generate-root.go @@ -83,14 +83,25 @@ func (c *GenerateRootCommand) Run(args []string) int { // If we are initing, or if we are not started but are not running a // special function, check otp and pgpkey - if init || - (!init && !cancel && !status && !genotp && len(decode) == 0 && !rootGenerationStatus.Started) { + checkOtpPgp := false + switch { + case init: + checkOtpPgp = true + case cancel: + case status: + case genotp: + case len(decode) != 0: + case rootGenerationStatus.Started: + default: + checkOtpPgp = true + } + if checkOtpPgp { switch { case len(otp) == 0 && (pgpKeyArr == nil || len(pgpKeyArr) == 0): - c.Ui.Error("-otp or -pgp-key must be specified") + c.Ui.Error(c.Help()) return 1 case len(otp) != 0 && pgpKeyArr != nil && len(pgpKeyArr) != 0: - c.Ui.Error("Only one of -otp or -pgp-key must be specified") + c.Ui.Error(c.Help()) return 1 case len(otp) != 0: err := c.verifyOTP(otp) diff --git a/helper/xor/xor.go b/helper/xor/xor.go index c83ea31ac5..4c5f88c537 100644 --- a/helper/xor/xor.go +++ b/helper/xor/xor.go @@ -1,35 +1,16 @@ package xor import ( - "crypto/rand" "encoding/base64" "fmt" ) -func GenerateRandBytes(length int) ([]byte, error) { - if length < 0 { - return nil, fmt.Errorf("length must be >= 0") - } - - buf := make([]byte, length) - if length == 0 { - return buf, nil - } - - n, err := rand.Read(buf) - if err != nil { - return nil, err - } - if n != length { - return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, n) - } - - return buf, nil -} - -func XORBuffers(a, b []byte) ([]byte, error) { +// XORBytes takes two byte slices and XORs them together, returning the final +// byte slice. It is an error to pass in two byte slices that do not have the +// same length. +func XORBytes(a, b []byte) ([]byte, error) { if len(a) != len(b) { - return nil, fmt.Errorf("length of buffers is not equivalent: %d != %d", len(a), len(b)) + return nil, fmt.Errorf("length of byte slices is not equivalent: %d != %d", len(a), len(b)) } buf := make([]byte, len(a)) @@ -41,6 +22,9 @@ func XORBuffers(a, b []byte) ([]byte, error) { return buf, nil } +// XORBase64 takes two base64-encoded strings and XORs the decoded byte slices +// together, returning the final byte slice. It is an error to pass in two +// strings that do not have the same length to their base64-decoded byte slice. func XORBase64(a, b string) ([]byte, error) { aBytes, err := base64.StdEncoding.DecodeString(a) if err != nil { @@ -58,14 +42,5 @@ func XORBase64(a, b string) ([]byte, error) { return nil, fmt.Errorf("decoded second base64 value is nil or empty") } - if len(aBytes) != len(bBytes) { - return nil, fmt.Errorf("decoded values are not same length: %d != %d", len(aBytes), len(bBytes)) - } - - buf := make([]byte, len(aBytes)) - for i, _ := range aBytes { - buf[i] = aBytes[i] ^ bBytes[i] - } - - return buf, nil + return XORBytes(aBytes, bBytes) } diff --git a/helper/xor/xor_test.go b/helper/xor/xor_test.go new file mode 100644 index 0000000000..2139d9166c --- /dev/null +++ b/helper/xor/xor_test.go @@ -0,0 +1,45 @@ +package xor + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "testing" +) + +const ( + tokenB64 = "ZGE0N2JiODkzYjhkMDYxYw==" + xorB64 = "iGiQYG9L0nIp+jRL5+Zk2w==" + expectedB64 = "7AmkVw0p6ksamAwv19BVuA==" +) + +func GenerateRandBytes(length int) ([]byte, error) { + if length < 0 { + return nil, fmt.Errorf("length must be >= 0") + } + + buf := make([]byte, length) + if length == 0 { + return buf, nil + } + + n, err := rand.Read(buf) + if err != nil { + return nil, err + } + if n != length { + return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, n) + } + + return buf, nil +} + +func TestBase64XOR(t *testing.T) { + ret, err := XORBase64(tokenB64, xorB64) + if err != nil { + t.Fatal(err) + } + if res := base64.StdEncoding.EncodeToString(ret); res != expectedB64 { + t.Fatalf("bad: %s", res) + } +}