Adding transit logical backend

This commit is contained in:
Armon Dadgar
2015-04-15 17:08:12 -07:00
parent 59073cf775
commit d02028a0e4
6 changed files with 515 additions and 2 deletions

View File

@@ -0,0 +1,35 @@
package transit
import (
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(map[string]string) (logical.Backend, error) {
return Backend(), nil
}
func Backend() *framework.Backend {
var b backend
b.Backend = &framework.Backend{
PathsSpecial: &logical.Paths{
Root: []string{
"policy/*",
},
},
Paths: []*framework.Path{
pathPolicy(),
pathEncrypt(),
pathDecrypt(),
},
Secrets: []*framework.Secret{},
}
return b.Backend
}
type backend struct {
*framework.Backend
}

View File

@@ -0,0 +1,132 @@
package transit
import (
"encoding/base64"
"fmt"
"testing"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
)
const (
testPlaintext = "the quick brown fox"
)
func TestBackend_basic(t *testing.T) {
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test"),
testAccStepReadPolicy(t, "test", false),
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true),
},
})
}
func testAccStepWritePolicy(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "policy/" + name,
}
}
func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "policy/" + name,
}
}
func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "policy/" + name,
Check: func(resp *logical.Response) error {
if resp == nil && !expectNone {
return fmt.Errorf("missing response")
} else if expectNone {
if resp != nil {
return fmt.Errorf("response when expecting none")
}
return nil
}
var d struct {
Name string `mapstructure:"name"`
Key []byte `mapstructure:"key"`
CipherMode string `mapstructure:"cipher_mode"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Name != name {
return fmt.Errorf("bad: %#v", d)
}
if d.CipherMode != "aes-gcm" {
return fmt.Errorf("bad: %#v", d)
}
if len(d.Key) != 32 {
return fmt.Errorf("bad: %#v", d)
}
return nil
},
}
}
func testAccStepEncrypt(
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "encrypt/" + name,
Data: map[string]interface{}{
"plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)),
},
Check: func(resp *logical.Response) error {
var d struct {
Ciphertext string `mapstructure:"ciphertext"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Ciphertext == "" {
return fmt.Errorf("missing ciphertext")
}
decryptData["ciphertext"] = d.Ciphertext
return nil
},
}
}
func testAccStepDecrypt(
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "decrypt/" + name,
Data: decryptData,
Check: func(resp *logical.Response) error {
var d struct {
Plaintext string `mapstructure:"plaintext"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
// Decode the base64
plainRaw, err := base64.StdEncoding.DecodeString(d.Plaintext)
if err != nil {
return err
}
if string(plainRaw) != plaintext {
return fmt.Errorf("plaintext mismatch: %s expect: %s", plainRaw, plaintext)
}
return nil
},
}
}

View File

@@ -0,0 +1,100 @@
package transit
import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"strings"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathDecrypt() *framework.Path {
return &framework.Path{
Pattern: `decrypt/(?P<name>\w+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
},
"ciphertext": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Ciphertext value to decrypt",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathDecryptWrite,
},
}
}
func pathDecryptWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
value := d.Get("ciphertext").(string)
if len(value) == 0 {
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
}
// Get the policy
p, err := getPolicy(req, name)
if err != nil {
return nil, err
}
// Error if invalid policy
if p == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
// Guard against a potentially invalid cipher-mode
switch p.CipherMode {
case "aes-gcm":
default:
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
}
// Verify the prefix
if !strings.HasPrefix(value, "vault:v0:") {
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
}
// Decode the base64
decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(value, "vault:v0:"))
if err != nil {
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
}
// Setup the cipher
aesCipher, err := aes.NewCipher(p.Key)
if err != nil {
return nil, err
}
// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return nil, err
}
// Extract the nonce and ciphertext
nonce := decoded[:gcm.NonceSize()]
ciphertext := decoded[gcm.NonceSize():]
// Verify and Decrypt
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
}
// Generate the response
resp := &logical.Response{
Data: map[string]interface{}{
"plaintext": base64.StdEncoding.EncodeToString(plain),
},
}
return resp, nil
}

View File

@@ -0,0 +1,104 @@
package transit
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathEncrypt() *framework.Path {
return &framework.Path{
Pattern: `encrypt/(?P<name>\w+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
},
"plaintext": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Plaintext value to encrypt",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathEncryptWrite,
},
}
}
func pathEncryptWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
value := d.Get("plaintext").(string)
if len(value) == 0 {
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
}
// Decode the plaintext value
plaintext, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return logical.ErrorResponse("failed to decode plaintext as base64"), logical.ErrInvalidRequest
}
// Get the policy
p, err := getPolicy(req, name)
if err != nil {
return nil, err
}
// Error if invalid policy
if p == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
// Guard against a potentially invalid cipher-mode
switch p.CipherMode {
case "aes-gcm":
default:
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
}
// Setup the cipher
aesCipher, err := aes.NewCipher(p.Key)
if err != nil {
return nil, err
}
// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return nil, err
}
// Compute random nonce
nonce := make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce)
if err != nil {
return nil, err
}
// Encrypt and tag with GCM
out := gcm.Seal(nil, nonce, plaintext, nil)
// Place the encrypted data after the nonce
full := append(nonce, out...)
// Convert to base64
encoded := base64.StdEncoding.EncodeToString(full)
// Prepend some information
encoded = "vault:v0:" + encoded
// Generate the response
resp := &logical.Response{
Data: map[string]interface{}{
"ciphertext": encoded,
},
}
return resp, nil
}

View File

@@ -0,0 +1,140 @@
package transit
import (
"crypto/rand"
"encoding/json"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// Policy is the struct used to store metadata
type Policy struct {
Name string `json:"name"`
Key []byte `json:"key"`
CipherMode string `json:"cipher"`
}
func (p *Policy) Serialize() ([]byte, error) {
return json.Marshal(p)
}
func DeserializePolicy(buf []byte) (*Policy, error) {
p := new(Policy)
if err := json.Unmarshal(buf, p); err != nil {
return nil, err
}
return p, nil
}
func getPolicy(req *logical.Request, name string) (*Policy, error) {
// Check if the policy already exists
raw, err := req.Storage.Get("policy/" + name)
if err != nil {
return nil, err
}
if raw == nil {
return nil, nil
}
// Decode the policy
p, err := DeserializePolicy(raw.Value)
if err != nil {
return nil, err
}
return p, nil
}
func pathPolicy() *framework.Path {
return &framework.Path{
Pattern: `policy/(?P<name>\w+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathPolicyWrite,
logical.DeleteOperation: pathPolicyDelete,
logical.ReadOperation: pathPolicyRead,
},
}
}
func pathPolicyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
// Check if the policy already exists
existing, err := getPolicy(req, name)
if err != nil {
return nil, err
}
if existing != nil {
return nil, nil
}
// Create the policy object
p := &Policy{
Name: name,
CipherMode: "aes-gcm",
}
// Generate a 256bit key
p.Key = make([]byte, 32)
_, err = rand.Read(p.Key)
if err != nil {
return nil, err
}
// Encode the policy
buf, err := p.Serialize()
if err != nil {
return nil, err
}
// Write the policy into storage
err = req.Storage.Put(&logical.StorageEntry{
Key: "policy/" + name,
Value: buf,
})
if err != nil {
return nil, err
}
return nil, nil
}
func pathPolicyRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
p, err := getPolicy(req, name)
if err != nil {
return nil, err
}
if p == nil {
return nil, nil
}
// Return the response
resp := &logical.Response{
Data: map[string]interface{}{
"name": p.Name,
"key": p.Key,
"cipher_mode": p.CipherMode,
},
}
return resp, nil
}
func pathPolicyDelete(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
err := req.Storage.Delete("policy/" + name)
if err != nil {
return nil, err
}
return nil, nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/hashicorp/vault/builtin/logical/aws"
"github.com/hashicorp/vault/builtin/logical/consul"
"github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/audit"
tokenDisk "github.com/hashicorp/vault/builtin/token/disk"
@@ -51,8 +52,9 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory {
"github": credGitHub.Factory,
},
LogicalBackends: map[string]logical.Factory{
"aws": aws.Factory,
"consul": consul.Factory,
"aws": aws.Factory,
"consul": consul.Factory,
"transit": transit.Factory,
},
}, nil
},